Skip to content

Commit 35ecbac

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Added GA support for running Custom and Hp tuning jobs on Persistent Resources
PiperOrigin-RevId: 620975901
1 parent 6aaa5d0 commit 35ecbac

File tree

4 files changed

+60
-51
lines changed

4 files changed

+60
-51
lines changed

google/cloud/aiplatform/jobs.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1731,6 +1731,7 @@ def __init__(
17311731
labels: Optional[Dict[str, str]] = None,
17321732
encryption_spec_key_name: Optional[str] = None,
17331733
staging_bucket: Optional[str] = None,
1734+
persistent_resource_id: Optional[str] = None,
17341735
):
17351736
"""Constructs a Custom Job with Worker Pool Specs.
17361737
@@ -1802,6 +1803,13 @@ def __init__(
18021803
staging_bucket (str):
18031804
Optional. Bucket for produced custom job artifacts. Overrides
18041805
staging_bucket set in aiplatform.init.
1806+
persistent_resource_id (str):
1807+
Optional. The ID of the PersistentResource in the same Project
1808+
and Location. If this is specified, the job will be run on
1809+
existing machines held by the PersistentResource instead of
1810+
on-demand short-live machines. The network and CMEK configs on
1811+
the job should be consistent with those on the PersistentResource,
1812+
otherwise, the job will be rejected.
18051813
18061814
Raises:
18071815
RuntimeError: If staging bucket was not set using aiplatform.init
@@ -1836,6 +1844,7 @@ def __init__(
18361844
base_output_directory=gca_io_compat.GcsDestination(
18371845
output_uri_prefix=base_output_dir
18381846
),
1847+
persistent_resource_id=persistent_resource_id,
18391848
),
18401849
labels=labels,
18411850
encryption_spec=initializer.global_config.get_encryption_spec(
@@ -2669,7 +2678,8 @@ def __init__(
26692678
of any UTF-8 characters.
26702679
custom_job (aiplatform.CustomJob):
26712680
Required. Configured CustomJob. The worker pool spec from this custom job
2672-
applies to the CustomJobs created in all the trials.
2681+
applies to the CustomJobs created in all the trials. A persistent_resource_id can be
2682+
specified on the custom job to be used when running this Hyperparameter Tuning job.
26732683
metric_spec: Dict[str, str]
26742684
Required. Dictionary representing metrics to optimize. The dictionary key is the metric_id,
26752685
which is reported by your training job, and the dictionary value is the

google/cloud/aiplatform/preview/jobs.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070

7171

7272
class CustomJob(jobs.CustomJob):
73-
"""Vertex AI Custom Job."""
73+
"""Deprecated. Vertex AI Custom Job (preview)."""
7474

7575
def __init__(
7676
self,
@@ -88,7 +88,9 @@ def __init__(
8888
staging_bucket: Optional[str] = None,
8989
persistent_resource_id: Optional[str] = None,
9090
):
91-
"""Constructs a Custom Job with Worker Pool Specs.
91+
"""Deprecated. Please use the GA (non-preview) version of this class.
92+
93+
Constructs a Custom Job with Worker Pool Specs.
9294
9395
```
9496
Example usage:
@@ -472,7 +474,7 @@ def submit(
472474

473475

474476
class HyperparameterTuningJob(jobs.HyperparameterTuningJob):
475-
"""Vertex AI Hyperparameter Tuning Job."""
477+
"""Deprecated. Vertex AI Hyperparameter Tuning Job (preview)."""
476478

477479
def __init__(
478480
self,
@@ -492,7 +494,8 @@ def __init__(
492494
labels: Optional[Dict[str, str]] = None,
493495
encryption_spec_key_name: Optional[str] = None,
494496
):
495-
"""
497+
"""Deprecated. Please use the GA (non-preview) version of this class.
498+
496499
Configures a HyperparameterTuning Job.
497500
498501
Example usage:

tests/unit/aiplatform/test_custom_job_persistent_resource.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,12 @@
2020
from unittest.mock import patch
2121

2222
from google.cloud import aiplatform
23-
from google.cloud.aiplatform.compat.services import (
24-
job_service_client_v1beta1,
25-
)
26-
from google.cloud.aiplatform.compat.types import custom_job_v1beta1
27-
from google.cloud.aiplatform.compat.types import encryption_spec_v1beta1
28-
from google.cloud.aiplatform.compat.types import io_v1beta1
29-
from google.cloud.aiplatform.compat.types import (
30-
job_state_v1beta1 as gca_job_state_compat,
31-
)
32-
from google.cloud.aiplatform.preview import jobs
23+
from google.cloud.aiplatform import jobs
24+
from google.cloud.aiplatform.compat.services import job_service_client_v1
25+
from google.cloud.aiplatform.compat.types import custom_job_v1
26+
from google.cloud.aiplatform.compat.types import encryption_spec_v1
27+
from google.cloud.aiplatform.compat.types import io_v1
28+
from google.cloud.aiplatform.compat.types import job_state_v1 as gca_job_state_compat
3329
import constants as test_constants
3430
import pytest
3531

@@ -58,7 +54,7 @@
5854

5955
# CMEK encryption
6056
_TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_1234"
61-
_TEST_DEFAULT_ENCRYPTION_SPEC = encryption_spec_v1beta1.EncryptionSpec(
57+
_TEST_DEFAULT_ENCRYPTION_SPEC = encryption_spec_v1.EncryptionSpec(
6258
kms_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME
6359
)
6460

@@ -78,14 +74,14 @@
7874

7975
# Persistent Resource
8076
_TEST_PERSISTENT_RESOURCE_ID = "test-persistent-resource-1"
81-
_TEST_CUSTOM_JOB_WITH_PERSISTENT_RESOURCE_PROTO = custom_job_v1beta1.CustomJob(
77+
_TEST_CUSTOM_JOB_WITH_PERSISTENT_RESOURCE_PROTO = custom_job_v1.CustomJob(
8278
display_name=_TEST_DISPLAY_NAME,
83-
job_spec=custom_job_v1beta1.CustomJobSpec(
79+
job_spec=custom_job_v1.CustomJobSpec(
8480
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
85-
base_output_directory=io_v1beta1.GcsDestination(
81+
base_output_directory=io_v1.GcsDestination(
8682
output_uri_prefix=_TEST_BASE_OUTPUT_DIR
8783
),
88-
scheduling=custom_job_v1beta1.Scheduling(
84+
scheduling=custom_job_v1.Scheduling(
8985
timeout=duration_pb2.Duration(seconds=_TEST_TIMEOUT),
9086
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
9187
disable_retries=_TEST_DISABLE_RETRIES,
@@ -108,21 +104,21 @@ def _get_custom_job_proto(state=None, name=None, error=None):
108104

109105

110106
@pytest.fixture
111-
def create_preview_custom_job_mock():
107+
def create_custom_job_mock():
112108
with mock.patch.object(
113-
job_service_client_v1beta1.JobServiceClient, "create_custom_job"
114-
) as create_preview_custom_job_mock:
115-
create_preview_custom_job_mock.return_value = _get_custom_job_proto(
109+
job_service_client_v1.JobServiceClient, "create_custom_job"
110+
) as create_custom_job_mock:
111+
create_custom_job_mock.return_value = _get_custom_job_proto(
116112
name=_TEST_CUSTOM_JOB_NAME,
117113
state=gca_job_state_compat.JobState.JOB_STATE_PENDING,
118114
)
119-
yield create_preview_custom_job_mock
115+
yield create_custom_job_mock
120116

121117

122118
@pytest.fixture
123119
def get_custom_job_mock():
124120
with patch.object(
125-
job_service_client_v1beta1.JobServiceClient, "get_custom_job"
121+
job_service_client_v1.JobServiceClient, "get_custom_job"
126122
) as get_custom_job_mock:
127123
get_custom_job_mock.side_effect = [
128124
_get_custom_job_proto(
@@ -152,7 +148,7 @@ def teardown_method(self):
152148

153149
@pytest.mark.parametrize("sync", [True, False])
154150
def test_create_custom_job_with_persistent_resource(
155-
self, create_preview_custom_job_mock, get_custom_job_mock, sync
151+
self, create_custom_job_mock, get_custom_job_mock, sync
156152
):
157153

158154
aiplatform.init(
@@ -188,7 +184,7 @@ def test_create_custom_job_with_persistent_resource(
188184

189185
expected_custom_job = _get_custom_job_proto()
190186

191-
create_preview_custom_job_mock.assert_called_once_with(
187+
create_custom_job_mock.assert_called_once_with(
192188
parent=_TEST_PARENT,
193189
custom_job=expected_custom_job,
194190
timeout=None,
@@ -201,7 +197,7 @@ def test_create_custom_job_with_persistent_resource(
201197
assert job.network == _TEST_NETWORK
202198

203199
def test_submit_custom_job_with_persistent_resource(
204-
self, create_preview_custom_job_mock, get_custom_job_mock
200+
self, create_custom_job_mock, get_custom_job_mock
205201
):
206202

207203
aiplatform.init(
@@ -236,7 +232,7 @@ def test_submit_custom_job_with_persistent_resource(
236232

237233
expected_custom_job = _get_custom_job_proto()
238234

239-
create_preview_custom_job_mock.assert_called_once_with(
235+
create_custom_job_mock.assert_called_once_with(
240236
parent=_TEST_PARENT,
241237
custom_job=expected_custom_job,
242238
timeout=None,

tests/unit/aiplatform/test_hyperparameter_tuning_job_persistent_resource.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,18 @@
2121

2222
from google.cloud import aiplatform
2323
from google.cloud.aiplatform.compat.services import (
24-
job_service_client_v1beta1,
24+
job_service_client_v1,
2525
)
2626
from google.cloud.aiplatform import hyperparameter_tuning as hpt
2727
from google.cloud.aiplatform.compat.types import (
28-
custom_job_v1beta1,
29-
encryption_spec_v1beta1,
30-
hyperparameter_tuning_job_v1beta1,
31-
io_v1beta1,
32-
job_state_v1beta1 as gca_job_state_compat,
33-
study_v1beta1 as gca_study_compat,
28+
custom_job_v1,
29+
encryption_spec_v1,
30+
hyperparameter_tuning_job_v1,
31+
io_v1,
32+
job_state_v1 as gca_job_state_compat,
33+
study_v1 as gca_study_compat,
3434
)
35-
from google.cloud.aiplatform.preview import jobs
35+
from google.cloud.aiplatform import jobs
3636
import constants as test_constants
3737
import pytest
3838

@@ -59,7 +59,7 @@
5959

6060
# CMEK encryption
6161
_TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_1234"
62-
_TEST_DEFAULT_ENCRYPTION_SPEC = encryption_spec_v1beta1.EncryptionSpec(
62+
_TEST_DEFAULT_ENCRYPTION_SPEC = encryption_spec_v1.EncryptionSpec(
6363
kms_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME
6464
)
6565

@@ -95,12 +95,12 @@
9595
# Persistent Resource
9696
_TEST_PERSISTENT_RESOURCE_ID = "test-persistent-resource-1"
9797

98-
_TEST_TRIAL_JOB_SPEC = custom_job_v1beta1.CustomJobSpec(
98+
_TEST_TRIAL_JOB_SPEC = custom_job_v1.CustomJobSpec(
9999
worker_pool_specs=test_constants.TrainingJobConstants._TEST_WORKER_POOL_SPEC,
100-
base_output_directory=io_v1beta1.GcsDestination(
100+
base_output_directory=io_v1.GcsDestination(
101101
output_uri_prefix=test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR
102102
),
103-
scheduling=custom_job_v1beta1.Scheduling(
103+
scheduling=custom_job_v1.Scheduling(
104104
timeout=duration_pb2.Duration(seconds=_TEST_TIMEOUT),
105105
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
106106
disable_retries=_TEST_DISABLE_RETRIES,
@@ -110,7 +110,7 @@
110110
persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
111111
)
112112

113-
_TEST_BASE_HYPERPARAMETER_TUNING_JOB_WITH_PERSISTENT_RESOURCE_PROTO = hyperparameter_tuning_job_v1beta1.HyperparameterTuningJob(
113+
_TEST_BASE_HYPERPARAMETER_TUNING_JOB_WITH_PERSISTENT_RESOURCE_PROTO = hyperparameter_tuning_job_v1.HyperparameterTuningJob(
114114
display_name=_TEST_DISPLAY_NAME,
115115
study_spec=gca_study_compat.StudySpec(
116116
metrics=[
@@ -197,23 +197,23 @@ def _get_hyperparameter_tuning_job_proto(state=None, name=None, error=None):
197197

198198

199199
@pytest.fixture
200-
def create_preview_hyperparameter_tuning_job_mock():
200+
def create_hyperparameter_tuning_job_mock():
201201
with mock.patch.object(
202-
job_service_client_v1beta1.JobServiceClient, "create_hyperparameter_tuning_job"
203-
) as create_preview_hyperparameter_tuning_job_mock:
204-
create_preview_hyperparameter_tuning_job_mock.return_value = (
202+
job_service_client_v1.JobServiceClient, "create_hyperparameter_tuning_job"
203+
) as create_hyperparameter_tuning_job_mock:
204+
create_hyperparameter_tuning_job_mock.return_value = (
205205
_get_hyperparameter_tuning_job_proto(
206206
name=_TEST_HYPERPARAMETERTUNING_JOB_NAME,
207207
state=gca_job_state_compat.JobState.JOB_STATE_PENDING,
208208
)
209209
)
210-
yield create_preview_hyperparameter_tuning_job_mock
210+
yield create_hyperparameter_tuning_job_mock
211211

212212

213213
@pytest.fixture
214214
def get_hyperparameter_tuning_job_mock():
215215
with patch.object(
216-
job_service_client_v1beta1.JobServiceClient, "get_hyperparameter_tuning_job"
216+
job_service_client_v1.JobServiceClient, "get_hyperparameter_tuning_job"
217217
) as get_hyperparameter_tuning_job_mock:
218218
get_hyperparameter_tuning_job_mock.side_effect = [
219219
_get_hyperparameter_tuning_job_proto(
@@ -248,7 +248,7 @@ def teardown_method(self):
248248
@pytest.mark.parametrize("sync", [True, False])
249249
def test_create_hyperparameter_tuning_job_with_persistent_resource(
250250
self,
251-
create_preview_hyperparameter_tuning_job_mock,
251+
create_hyperparameter_tuning_job_mock,
252252
get_hyperparameter_tuning_job_mock,
253253
sync,
254254
):
@@ -308,7 +308,7 @@ def test_create_hyperparameter_tuning_job_with_persistent_resource(
308308

309309
expected_hyperparameter_tuning_job = _get_hyperparameter_tuning_job_proto()
310310

311-
create_preview_hyperparameter_tuning_job_mock.assert_called_once_with(
311+
create_hyperparameter_tuning_job_mock.assert_called_once_with(
312312
parent=_TEST_PARENT,
313313
hyperparameter_tuning_job=expected_hyperparameter_tuning_job,
314314
timeout=None,

0 commit comments

Comments
 (0)