Skip to content

Commit 3a776a7

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Allow adapter_size=32 for supervised tuning
PiperOrigin-RevId: 786356415
1 parent 89ce1ae commit 3a776a7

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

vertexai/preview/tuning/_supervised_tuning.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def train(
3939
epochs: Optional[int] = None,
4040
learning_rate: Optional[float] = None,
4141
learning_rate_multiplier: Optional[float] = None,
42-
adapter_size: Optional[Literal[1, 4, 8, 16]] = None,
42+
adapter_size: Optional[Literal[1, 4, 8, 16, 32]] = None,
4343
labels: Optional[Dict[str, str]] = None,
4444
output_uri: Optional[str] = None,
4545
) -> "SupervisedTuningJob":
@@ -103,9 +103,13 @@ def train(
103103
adapter_size_value = (
104104
gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_SIXTEEN
105105
)
106+
elif adapter_size == 32:
107+
adapter_size_value = (
108+
gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_THIRTY_TWO
109+
)
106110
else:
107111
raise ValueError(
108-
f"Unsupported adapter size: {adapter_size}. The supported sizes are [1, 4, 8, 16]"
112+
f"Unsupported adapter size: {adapter_size}. The supported sizes are [1, 4, 8, 16, 32]"
109113
)
110114
if isinstance(train_dataset, datasets.MultimodalDataset):
111115
train_dataset = train_dataset.resource_name

vertexai/tuning/_supervised_tuning.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def train(
3232
tuned_model_display_name: Optional[str] = None,
3333
epochs: Optional[int] = None,
3434
learning_rate_multiplier: Optional[float] = None,
35-
adapter_size: Optional[Literal[1, 4, 8, 16]] = None,
35+
adapter_size: Optional[Literal[1, 4, 8, 16, 32]] = None,
3636
labels: Optional[Dict[str, str]] = None,
3737
) -> "SupervisedTuningJob":
3838
"""Tunes a model using supervised training.
@@ -70,9 +70,13 @@ def train(
7070
adapter_size_value = (
7171
gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_SIXTEEN
7272
)
73+
elif adapter_size == 32:
74+
adapter_size_value = (
75+
gca_tuning_job_types.SupervisedHyperParameters.AdapterSize.ADAPTER_SIZE_THIRTY_TWO
76+
)
7377
else:
7478
raise ValueError(
75-
f"Unsupported adapter size: {adapter_size}. The supported sizes are [1, 4, 8, 16]"
79+
f"Unsupported adapter size: {adapter_size}. The supported sizes are [1, 4, 8, 16, 32]"
7680
)
7781
if isinstance(train_dataset, datasets.MultimodalDataset):
7882
train_dataset = train_dataset.resource_name

0 commit comments

Comments
 (0)