@@ -39,7 +39,7 @@ def train(
39
39
epochs : Optional [int ] = None ,
40
40
learning_rate : Optional [float ] = None ,
41
41
learning_rate_multiplier : Optional [float ] = None ,
42
- adapter_size : Optional [Literal [1 , 4 , 8 , 16 ]] = None ,
42
+ adapter_size : Optional [Literal [1 , 4 , 8 , 16 , 32 ]] = None ,
43
43
labels : Optional [Dict [str , str ]] = None ,
44
44
output_uri : Optional [str ] = None ,
45
45
) -> "SupervisedTuningJob" :
@@ -103,9 +103,13 @@ def train(
103
103
adapter_size_value = (
104
104
gca_tuning_job_types .SupervisedHyperParameters .AdapterSize .ADAPTER_SIZE_SIXTEEN
105
105
)
106
+ elif adapter_size == 32 :
107
+ adapter_size_value = (
108
+ gca_tuning_job_types .SupervisedHyperParameters .AdapterSize .ADAPTER_SIZE_THIRTY_TWO
109
+ )
106
110
else :
107
111
raise ValueError (
108
- f"Unsupported adapter size: { adapter_size } . The supported sizes are [1, 4, 8, 16]"
112
+ f"Unsupported adapter size: { adapter_size } . The supported sizes are [1, 4, 8, 16, 32 ]"
109
113
)
110
114
if isinstance (train_dataset , datasets .MultimodalDataset ):
111
115
train_dataset = train_dataset .resource_name
0 commit comments