Skip to content

Commit 56273f7

Browse files
authored
feat: Add VPC Peering support to CustomTrainingJob classes (#378)
* Add 'network' for VPC Peering in custom training * Blacken code
1 parent 7eaedb6 commit 56273f7

File tree

2 files changed

+66
-3
lines changed

2 files changed

+66
-3
lines changed

google/cloud/aiplatform/training_jobs.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1526,6 +1526,7 @@ def _prepare_training_task_inputs_and_output_dir(
15261526
worker_pool_specs: _DistributedTrainingSpec,
15271527
base_output_dir: Optional[str] = None,
15281528
service_account: Optional[str] = None,
1529+
network: Optional[str] = None,
15291530
) -> Tuple[Dict, str]:
15301531
"""Prepares training task inputs and output directory for custom job.
15311532
@@ -1538,6 +1539,11 @@ def _prepare_training_task_inputs_and_output_dir(
15381539
service_account (str):
15391540
Specifies the service account for workload run-as account.
15401541
Users submitting jobs must have act-as permission on this run-as account.
1542+
network (str):
1543+
The full name of the Compute Engine network to which the job
1544+
should be peered. For example, projects/12345/global/networks/myVPC.
1545+
Private services access must already be configured for the network.
1546+
If left unspecified, the job is not peered with any network.
15411547
Returns:
15421548
Training task inputs and Output directory for custom job.
15431549
"""
@@ -1556,6 +1562,8 @@ def _prepare_training_task_inputs_and_output_dir(
15561562

15571563
if service_account:
15581564
training_task_inputs["serviceAccount"] = service_account
1565+
if network:
1566+
training_task_inputs["network"] = network
15591567

15601568
return training_task_inputs, base_output_dir
15611569

@@ -1803,6 +1811,7 @@ def run(
18031811
model_display_name: Optional[str] = None,
18041812
base_output_dir: Optional[str] = None,
18051813
service_account: Optional[str] = None,
1814+
network: Optional[str] = None,
18061815
bigquery_destination: Optional[str] = None,
18071816
args: Optional[List[Union[str, float, int]]] = None,
18081817
environment_variables: Optional[Dict[str, str]] = None,
@@ -1891,6 +1900,11 @@ def run(
18911900
service_account (str):
18921901
Specifies the service account for workload run-as account.
18931902
Users submitting jobs must have act-as permission on this run-as account.
1903+
network (str):
1904+
The full name of the Compute Engine network to which the job
1905+
should be peered. For example, projects/12345/global/networks/myVPC.
1906+
Private services access must already be configured for the network.
1907+
If left unspecified, the job is not peered with any network.
18941908
bigquery_destination (str):
18951909
Provide this field if `dataset` is a BiqQuery dataset.
18961910
The BigQuery project location where the training data is to
@@ -1981,6 +1995,7 @@ def run(
19811995
environment_variables=environment_variables,
19821996
base_output_dir=base_output_dir,
19831997
service_account=service_account,
1998+
network=network,
19841999
bigquery_destination=bigquery_destination,
19852000
training_fraction_split=training_fraction_split,
19862001
validation_fraction_split=validation_fraction_split,
@@ -2008,6 +2023,7 @@ def _run(
20082023
environment_variables: Optional[Dict[str, str]] = None,
20092024
base_output_dir: Optional[str] = None,
20102025
service_account: Optional[str] = None,
2026+
network: Optional[str] = None,
20112027
bigquery_destination: Optional[str] = None,
20122028
training_fraction_split: float = 0.8,
20132029
validation_fraction_split: float = 0.1,
@@ -2061,6 +2077,11 @@ def _run(
20612077
service_account (str):
20622078
Specifies the service account for workload run-as account.
20632079
Users submitting jobs must have act-as permission on this run-as account.
2080+
network (str):
2081+
The full name of the Compute Engine network to which the job
2082+
should be peered. For example, projects/12345/global/networks/myVPC.
2083+
Private services access must already be configured for the network.
2084+
If left unspecified, the job is not peered with any network.
20642085
bigquery_destination (str):
20652086
Provide this field if `dataset` is a BiqQuery dataset.
20662087
The BigQuery project location where the training data is to
@@ -2130,7 +2151,10 @@ def _run(
21302151
training_task_inputs,
21312152
base_output_dir,
21322153
) = self._prepare_training_task_inputs_and_output_dir(
2133-
worker_pool_specs, base_output_dir, service_account
2154+
worker_pool_specs=worker_pool_specs,
2155+
base_output_dir=base_output_dir,
2156+
service_account=service_account,
2157+
network=network,
21342158
)
21352159

21362160
model = self._run_job(
@@ -2375,6 +2399,7 @@ def run(
23752399
model_display_name: Optional[str] = None,
23762400
base_output_dir: Optional[str] = None,
23772401
service_account: Optional[str] = None,
2402+
network: Optional[str] = None,
23782403
bigquery_destination: Optional[str] = None,
23792404
args: Optional[List[Union[str, float, int]]] = None,
23802405
environment_variables: Optional[Dict[str, str]] = None,
@@ -2456,6 +2481,11 @@ def run(
24562481
service_account (str):
24572482
Specifies the service account for workload run-as account.
24582483
Users submitting jobs must have act-as permission on this run-as account.
2484+
network (str):
2485+
The full name of the Compute Engine network to which the job
2486+
should be peered. For example, projects/12345/global/networks/myVPC.
2487+
Private services access must already be configured for the network.
2488+
If left unspecified, the job is not peered with any network.
24592489
bigquery_destination (str):
24602490
Provide this field if `dataset` is a BiqQuery dataset.
24612491
The BigQuery project location where the training data is to
@@ -2545,6 +2575,7 @@ def run(
25452575
environment_variables=environment_variables,
25462576
base_output_dir=base_output_dir,
25472577
service_account=service_account,
2578+
network=network,
25482579
bigquery_destination=bigquery_destination,
25492580
training_fraction_split=training_fraction_split,
25502581
validation_fraction_split=validation_fraction_split,
@@ -2571,6 +2602,7 @@ def _run(
25712602
environment_variables: Optional[Dict[str, str]] = None,
25722603
base_output_dir: Optional[str] = None,
25732604
service_account: Optional[str] = None,
2605+
network: Optional[str] = None,
25742606
bigquery_destination: Optional[str] = None,
25752607
training_fraction_split: float = 0.8,
25762608
validation_fraction_split: float = 0.1,
@@ -2621,6 +2653,11 @@ def _run(
26212653
service_account (str):
26222654
Specifies the service account for workload run-as account.
26232655
Users submitting jobs must have act-as permission on this run-as account.
2656+
network (str):
2657+
The full name of the Compute Engine network to which the job
2658+
should be peered. For example, projects/12345/global/networks/myVPC.
2659+
Private services access must already be configured for the network.
2660+
If left unspecified, the job is not peered with any network.
26242661
bigquery_destination (str):
26252662
The BigQuery project location where the training data is to
26262663
be written to. In the given project a new dataset is created
@@ -2683,7 +2720,10 @@ def _run(
26832720
training_task_inputs,
26842721
base_output_dir,
26852722
) = self._prepare_training_task_inputs_and_output_dir(
2686-
worker_pool_specs, base_output_dir, service_account
2723+
worker_pool_specs=worker_pool_specs,
2724+
base_output_dir=base_output_dir,
2725+
service_account=service_account,
2726+
network=network,
26872727
)
26882728

26892729
model = self._run_job(
@@ -3709,6 +3749,7 @@ def run(
37093749
model_display_name: Optional[str] = None,
37103750
base_output_dir: Optional[str] = None,
37113751
service_account: Optional[str] = None,
3752+
network: Optional[str] = None,
37123753
bigquery_destination: Optional[str] = None,
37133754
args: Optional[List[Union[str, float, int]]] = None,
37143755
environment_variables: Optional[Dict[str, str]] = None,
@@ -3790,6 +3831,11 @@ def run(
37903831
service_account (str):
37913832
Specifies the service account for workload run-as account.
37923833
Users submitting jobs must have act-as permission on this run-as account.
3834+
network (str):
3835+
The full name of the Compute Engine network to which the job
3836+
should be peered. For example, projects/12345/global/networks/myVPC.
3837+
Private services access must already be configured for the network.
3838+
If left unspecified, the job is not peered with any network.
37933839
bigquery_destination (str):
37943840
Provide this field if `dataset` is a BiqQuery dataset.
37953841
The BigQuery project location where the training data is to
@@ -3874,6 +3920,7 @@ def run(
38743920
environment_variables=environment_variables,
38753921
base_output_dir=base_output_dir,
38763922
service_account=service_account,
3923+
network=network,
38773924
training_fraction_split=training_fraction_split,
38783925
validation_fraction_split=validation_fraction_split,
38793926
test_fraction_split=test_fraction_split,
@@ -3900,6 +3947,7 @@ def _run(
39003947
environment_variables: Optional[Dict[str, str]] = None,
39013948
base_output_dir: Optional[str] = None,
39023949
service_account: Optional[str] = None,
3950+
network: Optional[str] = None,
39033951
training_fraction_split: float = 0.8,
39043952
validation_fraction_split: float = 0.1,
39053953
test_fraction_split: float = 0.1,
@@ -3951,6 +3999,11 @@ def _run(
39513999
service_account (str):
39524000
Specifies the service account for workload run-as account.
39534001
Users submitting jobs must have act-as permission on this run-as account.
4002+
network (str):
4003+
The full name of the Compute Engine network to which the job
4004+
should be peered. For example, projects/12345/global/networks/myVPC.
4005+
Private services access must already be configured for the network.
4006+
If left unspecified, the job is not peered with any network.
39544007
training_fraction_split (float):
39554008
The fraction of the input data that is to be
39564009
used to train the Model.
@@ -3999,7 +4052,10 @@ def _run(
39994052
training_task_inputs,
40004053
base_output_dir,
40014054
) = self._prepare_training_task_inputs_and_output_dir(
4002-
worker_pool_specs, base_output_dir, service_account
4055+
worker_pool_specs=worker_pool_specs,
4056+
base_output_dir=base_output_dir,
4057+
service_account=service_account,
4058+
network=network,
40034059
)
40044060

40054061
model = self._run_job(

tests/unit/aiplatform/test_training_jobs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@
109109
)
110110
_TEST_ALT_PROJECT = "test-project-alt"
111111
_TEST_ALT_LOCATION = "europe-west4"
112+
_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_ID}"
112113

113114
_TEST_MODEL_INSTANCE_SCHEMA_URI = "instance_schema_uri.yaml"
114115
_TEST_MODEL_PARAMETERS_SCHEMA_URI = "parameters_schema_uri.yaml"
@@ -598,6 +599,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
598599
dataset=mock_tabular_dataset,
599600
base_output_dir=_TEST_BASE_OUTPUT_DIR,
600601
service_account=_TEST_SERVICE_ACCOUNT,
602+
network=_TEST_NETWORK,
601603
args=_TEST_RUN_ARGS,
602604
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
603605
replica_count=1,
@@ -700,6 +702,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
700702
"workerPoolSpecs": [true_worker_pool_spec],
701703
"baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR},
702704
"serviceAccount": _TEST_SERVICE_ACCOUNT,
705+
"network": _TEST_NETWORK,
703706
},
704707
struct_pb2.Value(),
705708
),
@@ -2539,6 +2542,7 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset(
25392542
annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI,
25402543
base_output_dir=_TEST_BASE_OUTPUT_DIR,
25412544
service_account=_TEST_SERVICE_ACCOUNT,
2545+
network=_TEST_NETWORK,
25422546
args=_TEST_RUN_ARGS,
25432547
replica_count=1,
25442548
machine_type=_TEST_MACHINE_TYPE,
@@ -2621,6 +2625,7 @@ def test_run_call_pipeline_service_create_with_nontabular_dataset(
26212625
"workerPoolSpecs": [true_worker_pool_spec],
26222626
"baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR},
26232627
"serviceAccount": _TEST_SERVICE_ACCOUNT,
2628+
"network": _TEST_NETWORK,
26242629
},
26252630
struct_pb2.Value(),
26262631
),
@@ -2970,6 +2975,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
29702975
model_display_name=_TEST_MODEL_DISPLAY_NAME,
29712976
base_output_dir=_TEST_BASE_OUTPUT_DIR,
29722977
service_account=_TEST_SERVICE_ACCOUNT,
2978+
network=_TEST_NETWORK,
29732979
args=_TEST_RUN_ARGS,
29742980
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
29752981
replica_count=1,
@@ -3065,6 +3071,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
30653071
"workerPoolSpecs": [true_worker_pool_spec],
30663072
"baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR},
30673073
"serviceAccount": _TEST_SERVICE_ACCOUNT,
3074+
"network": _TEST_NETWORK,
30683075
},
30693076
struct_pb2.Value(),
30703077
),

0 commit comments

Comments
 (0)