diff --git a/.kokoro/continuous/system.cfg b/.kokoro/continuous/system.cfg
index bda5345ece..f5ed200a12 100644
--- a/.kokoro/continuous/system.cfg
+++ b/.kokoro/continuous/system.cfg
@@ -8,7 +8,7 @@ env_vars: {
# Run system tests in parallel, splitting up by file
env_vars: {
key: "PYTEST_ADDOPTS"
- value: "-n=auto --dist=loadscope"
+ value: "-n=16 --dist=loadscope"
}
# Kokoro VM timeout of 12 hours for system tests
diff --git a/.kokoro/samples/python3.10/common.cfg b/.kokoro/samples/python3.10/common.cfg
index a49138fd0a..17abca148a 100644
--- a/.kokoro/samples/python3.10/common.cfg
+++ b/.kokoro/samples/python3.10/common.cfg
@@ -14,7 +14,7 @@ env_vars: {
}
# Declare build specific Cloud project.
-env_vars: {
+env_vars: {
key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
value: "ucaip-sample-tests"
}
diff --git a/.kokoro/samples/python3.11/common.cfg b/.kokoro/samples/python3.11/common.cfg
index c870d5b2c7..1166f2c317 100644
--- a/.kokoro/samples/python3.11/common.cfg
+++ b/.kokoro/samples/python3.11/common.cfg
@@ -14,7 +14,7 @@ env_vars: {
}
# Declare build specific Cloud project.
-env_vars: {
+env_vars: {
key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
value: "ucaip-sample-tests"
}
diff --git a/.kokoro/samples/python3.7/common.cfg b/.kokoro/samples/python3.7/common.cfg
index cc8296c89d..eed23ad9bc 100644
--- a/.kokoro/samples/python3.7/common.cfg
+++ b/.kokoro/samples/python3.7/common.cfg
@@ -14,7 +14,7 @@ env_vars: {
}
# Declare build specific Cloud project.
-env_vars: {
+env_vars: {
key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
value: "ucaip-sample-tests"
}
diff --git a/.kokoro/samples/python3.8/common.cfg b/.kokoro/samples/python3.8/common.cfg
index a118253a82..26e513823e 100644
--- a/.kokoro/samples/python3.8/common.cfg
+++ b/.kokoro/samples/python3.8/common.cfg
@@ -14,7 +14,7 @@ env_vars: {
}
# Declare build specific Cloud project.
-env_vars: {
+env_vars: {
key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
value: "ucaip-sample-tests"
}
diff --git a/.kokoro/samples/python3.9/common.cfg b/.kokoro/samples/python3.9/common.cfg
index 5a549c80fc..abda08ed27 100644
--- a/.kokoro/samples/python3.9/common.cfg
+++ b/.kokoro/samples/python3.9/common.cfg
@@ -14,7 +14,7 @@ env_vars: {
}
# Declare build specific Cloud project.
-env_vars: {
+env_vars: {
key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
value: "ucaip-sample-tests"
}
diff --git a/.release-please-manifest.json b/.release-please-manifest.json
index d0ed7ddc2d..ddf991e2a5 100644
--- a/.release-please-manifest.json
+++ b/.release-please-manifest.json
@@ -1,3 +1,3 @@
{
- ".": "1.30.1"
+ ".": "1.31.0"
}
diff --git a/CHANGELOG.md b/CHANGELOG.md
index df6b00277a..2316321130 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,35 @@
# Changelog
+## [1.31.0](https://github.com/googleapis/python-aiplatform/compare/v1.30.1...v1.31.0) (2023-08-21)
+
+
+### Features
+
+* Add disable_retries option to custom jobs. ([db518b0](https://github.com/googleapis/python-aiplatform/commit/db518b0552a8900ca6a84a73ca711b775c786e92))
+* LLM - Added support for `stop_sequences` in inference ([6f7ea84](https://github.com/googleapis/python-aiplatform/commit/6f7ea84415e5d0efcc49487c93b0f1d94fd68974))
+* LLM - Exposed the `TextGenerationResponse.raw_prediction_response` ([f8f2b9c](https://github.com/googleapis/python-aiplatform/commit/f8f2b9cdf88f40fe0b7e86948515ab1cf72d92be))
+* LLM - Made tuning asynchronous when tuning becomes GA ([226ab8b](https://github.com/googleapis/python-aiplatform/commit/226ab8b64efc01d7ce20cdf924e103d7673376cf))
+* LLM - release model evaluation for TextGenerationModel to public preview ([8df5185](https://github.com/googleapis/python-aiplatform/commit/8df5185d668292d5adc11ebf9477e2fdd44599d4))
+* LLM - Released `TextGenerationModel` tuning to GA ([62ff30d](https://github.com/googleapis/python-aiplatform/commit/62ff30daa718ac7869714c68e55d6955d6355945))
+* LLM - Support streaming prediction for chat models ([ce60cf7](https://github.com/googleapis/python-aiplatform/commit/ce60cf75ec5c83db8033b553e1ad7164159fb3be))
+* LLM - Support streaming prediction for code chat models ([0359f1d](https://github.com/googleapis/python-aiplatform/commit/0359f1dd83bf86df58d1145ddf5e4634d3c8e1ff))
+* LLM - Support streaming prediction for code generation models ([3a8348b](https://github.com/googleapis/python-aiplatform/commit/3a8348bca2d9c74e5e52fb9fc131fdb766f49a5c))
+* LLM - Support streaming prediction for text generation models ([fb527f3](https://github.com/googleapis/python-aiplatform/commit/fb527f3aa59ee90fa6306196b328f513ee4b4d9c))
+* LLM - TextEmbeddingModel - Added support for structural inputs (`TextEmbeddingInput`), `auto_truncate` parameter and result `statistics` ([cbf9b6e](https://github.com/googleapis/python-aiplatform/commit/cbf9b6ee806d7eb89725f53c4509858a272b3141))
+* LVM - Added support for Image Generation models ([b3729c1](https://github.com/googleapis/python-aiplatform/commit/b3729c11a70abaf061daa56ed4c483c4118d5acf))
+* LVM - Released `ImageCaptioningModel` to GA ([7575046](https://github.com/googleapis/python-aiplatform/commit/7575046d953e83bbb8aa13769f28e1eb50e04a7d))
+* LVM - Released `ImageQnAModel` to GA ([fd5cb02](https://github.com/googleapis/python-aiplatform/commit/fd5cb0226f4cff7ee160d2005c5907b81f847a1e))
+* LVM - Released `MultiModalEmbeddingModel` to GA ([e99f366](https://github.com/googleapis/python-aiplatform/commit/e99f366fde802b8677b785613e02fc4d9f94d729))
+* LVM - Removed the `width` and `height` parameters from `ImageGenerationModel.generate_images` since the service has dropped support for image sizes and aspect ratios ([52897e6](https://github.com/googleapis/python-aiplatform/commit/52897e669ff91d3bb991fcf05ae9a18df93df05f))
+* Scheduled pipelines client GA. ([62b8b23](https://github.com/googleapis/python-aiplatform/commit/62b8b23e1144ec547b8d181240090b744dd5201a))
+
+
+### Documentation
+
+* Generate documentation for tune_model and related class ([705e1ea](https://github.com/googleapis/python-aiplatform/commit/705e1ea402684f3ff4a4cf1f80c04b88bf6cf7db))
+* LVM - Added autogenerated documentation for visual models ([18e8bb2](https://github.com/googleapis/python-aiplatform/commit/18e8bb283e80fa9efb26f5fe3f8997b0b038bb12))
+
## [1.30.1](https://github.com/googleapis/python-aiplatform/compare/v1.30.0...v1.30.1) (2023-08-11)
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index 953a2fc715..4334f9fb0c 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -22,7 +22,7 @@ In order to add a feature:
documentation.
- The feature must work fully on the following CPython versions:
- 3.7, 3.8, 3.9 and 3.10 on both UNIX and Windows.
+ 3.7, 3.8, 3.9, 3.10 and 3.11 on both UNIX and Windows.
- The feature must not add unnecessary dependencies (where
"unnecessary" is of course subjective, but new dependencies should
@@ -72,7 +72,7 @@ We use `nox `__ to instrument our tests.
- To run a single unit test::
- $ nox -s unit-3.10 -- -k
+ $ nox -s unit-3.11 -- -k
.. note::
@@ -225,11 +225,13 @@ We support:
- `Python 3.8`_
- `Python 3.9`_
- `Python 3.10`_
+- `Python 3.11`_
.. _Python 3.7: https://docs.python.org/3.7/
.. _Python 3.8: https://docs.python.org/3.8/
.. _Python 3.9: https://docs.python.org/3.9/
.. _Python 3.10: https://docs.python.org/3.10/
+.. _Python 3.11: https://docs.python.org/3.11/
Supported versions can be found in our ``noxfile.py`` `config`_.
diff --git a/docs/vertexai/services.rst b/docs/vertexai/services.rst
index 12e47a0ffb..72ab574d86 100644
--- a/docs/vertexai/services.rst
+++ b/docs/vertexai/services.rst
@@ -11,6 +11,10 @@ Vertex AI SDK
:show-inheritance:
:inherited-members:
+.. automodule:: vertexai.language_models._language_models
+ :no-members:
+ :private-members: _TunableModelMixin
+
.. automodule:: vertexai.preview
:members:
:show-inheritance:
@@ -20,3 +24,13 @@ Vertex AI SDK
:members:
:show-inheritance:
:inherited-members:
+
+.. automodule:: vertexai.vision_models
+ :members:
+ :show-inheritance:
+ :inherited-members:
+
+.. automodule:: vertexai.preview.vision_models
+ :members:
+ :show-inheritance:
+ :inherited-members:
diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py
index c265005120..6435efaff5 100644
--- a/google/cloud/aiplatform/__init__.py
+++ b/google/cloud/aiplatform/__init__.py
@@ -56,6 +56,9 @@
ModelDeploymentMonitoringJob,
)
from google.cloud.aiplatform.pipeline_jobs import PipelineJob
+from google.cloud.aiplatform.pipeline_job_schedules import (
+ PipelineJobSchedule,
+)
from google.cloud.aiplatform.tensorboard import (
Tensorboard,
TensorboardExperiment,
@@ -167,6 +170,7 @@
"ModelEvaluation",
"ModelDeploymentMonitoringJob",
"PipelineJob",
+ "PipelineJobSchedule",
"PrivateEndpoint",
"RandomSampleConfig",
"SequenceToSequencePlusForecastingTrainingJob",
diff --git a/google/cloud/aiplatform/_streaming_prediction.py b/google/cloud/aiplatform/_streaming_prediction.py
new file mode 100644
index 0000000000..cd5a33a491
--- /dev/null
+++ b/google/cloud/aiplatform/_streaming_prediction.py
@@ -0,0 +1,166 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Streaming prediction functions."""
+
+from typing import Any, Dict, Iterator, List, Optional, Sequence
+
+from google.cloud.aiplatform_v1.services import prediction_service
+from google.cloud.aiplatform_v1.types import (
+ prediction_service as prediction_service_types,
+)
+from google.cloud.aiplatform_v1.types import (
+ types as aiplatform_types,
+)
+
+
+def value_to_tensor(value: Any) -> aiplatform_types.Tensor:
+ """Converts a Python value to `Tensor`.
+
+ Args:
+ value: A value to convert
+
+ Returns:
+ A `Tensor` object
+ """
+ if value is None:
+ return aiplatform_types.Tensor()
+ elif isinstance(value, int):
+ return aiplatform_types.Tensor(int_val=[value])
+ elif isinstance(value, float):
+ return aiplatform_types.Tensor(float_val=[value])
+ elif isinstance(value, bool):
+ return aiplatform_types.Tensor(bool_val=[value])
+ elif isinstance(value, str):
+ return aiplatform_types.Tensor(string_val=[value])
+ elif isinstance(value, bytes):
+ return aiplatform_types.Tensor(bytes_val=[value])
+ elif isinstance(value, list):
+ return aiplatform_types.Tensor(list_val=[value_to_tensor(x) for x in value])
+ elif isinstance(value, dict):
+ return aiplatform_types.Tensor(
+ struct_val={k: value_to_tensor(v) for k, v in value.items()}
+ )
+ raise TypeError(f"Unsupported value type {type(value)}")
+
+
+def tensor_to_value(tensor_pb: aiplatform_types.Tensor) -> Any:
+ """Converts `Tensor` to a Python value.
+
+ Args:
+ tensor_pb: A `Tensor` object
+
+ Returns:
+ A corresponding Python object
+ """
+ list_of_fields = tensor_pb.ListFields()
+ if not list_of_fields:
+ return None
+ descriptor, value = tensor_pb.ListFields()[0]
+ if descriptor.name == "list_val":
+ return [tensor_to_value(x) for x in value]
+ elif descriptor.name == "struct_val":
+ return {k: tensor_to_value(v) for k, v in value.items()}
+ if not isinstance(value, Sequence):
+ raise TypeError(f"Unexpected non-list tensor value {value}")
+ if len(value) == 1:
+ return value[0]
+ else:
+ return value
+
+
+def predict_stream_of_tensor_lists_from_single_tensor_list(
+ prediction_service_client: prediction_service.PredictionServiceClient,
+ endpoint_name: str,
+ tensor_list: List[aiplatform_types.Tensor],
+ parameters_tensor: Optional[aiplatform_types.Tensor] = None,
+) -> Iterator[List[aiplatform_types.Tensor]]:
+ """Predicts a stream of lists of `Tensor` objects from a single list of `Tensor` objects.
+
+ Args:
+ tensor_list: Model input as a list of `Tensor` objects.
+ parameters_tensor: Optional. Prediction parameters in `Tensor` form.
+ prediction_service_client: A PredictionServiceClient object.
+ endpoint_name: Resource name of Endpoint or PublisherModel.
+
+ Yields:
+ A generator of model prediction `Tensor` lists.
+ """
+ request = prediction_service_types.StreamingPredictRequest(
+ endpoint=endpoint_name,
+ inputs=tensor_list,
+ parameters=parameters_tensor,
+ )
+ for response in prediction_service_client.server_streaming_predict(request=request):
+ yield response.outputs
+
+
+def predict_stream_of_dict_lists_from_single_dict_list(
+ prediction_service_client: prediction_service.PredictionServiceClient,
+ endpoint_name: str,
+ dict_list: List[Dict[str, Any]],
+ parameters: Optional[Dict[str, Any]] = None,
+) -> Iterator[List[Dict[str, Any]]]:
+ """Predicts a stream of lists of dicts from a stream of lists of dicts.
+
+ Args:
+ dict_list: Model input as a list of `dict` objects.
+ parameters: Optional. Prediction parameters `dict` form.
+ prediction_service_client: A PredictionServiceClient object.
+ endpoint_name: Resource name of Endpoint or PublisherModel.
+
+ Yields:
+ A generator of model prediction dict lists.
+ """
+ tensor_list = [value_to_tensor(d) for d in dict_list]
+ parameters_tensor = value_to_tensor(parameters) if parameters else None
+ for tensor_list in predict_stream_of_tensor_lists_from_single_tensor_list(
+ prediction_service_client=prediction_service_client,
+ endpoint_name=endpoint_name,
+ tensor_list=tensor_list,
+ parameters_tensor=parameters_tensor,
+ ):
+ yield [tensor_to_value(tensor._pb) for tensor in tensor_list]
+
+
+def predict_stream_of_dicts_from_single_dict(
+ prediction_service_client: prediction_service.PredictionServiceClient,
+ endpoint_name: str,
+ instance: Dict[str, Any],
+ parameters: Optional[Dict[str, Any]] = None,
+) -> Iterator[Dict[str, Any]]:
+ """Predicts a stream of dicts from a single instance dict.
+
+ Args:
+ instance: A single input instance `dict`.
+ parameters: Optional. Prediction parameters `dict`.
+ prediction_service_client: A PredictionServiceClient object.
+ endpoint_name: Resource name of Endpoint or PublisherModel.
+
+ Yields:
+ A generator of model prediction dicts.
+ """
+ for dict_list in predict_stream_of_dict_lists_from_single_dict_list(
+ prediction_service_client=prediction_service_client,
+ endpoint_name=endpoint_name,
+ dict_list=[instance],
+ parameters=parameters,
+ ):
+ if len(dict_list) > 1:
+ raise ValueError(
+ f"Expected to receive a single output, but got {dict_list}"
+ )
+ yield dict_list[0]
diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py
index c78f61285e..c1a278c64f 100644
--- a/google/cloud/aiplatform/base.py
+++ b/google/cloud/aiplatform/base.py
@@ -69,15 +69,6 @@ def __init__(self, name: str):
self._logger = logging.getLogger(name)
self._logger.setLevel(logging.INFO)
- if self._logger.handlers:
- # Avoid writing duplicate logs if the logger is created twice.
- return
-
- handler = logging.StreamHandler(sys.stdout)
- handler.setLevel(logging.INFO)
-
- self._logger.addHandler(handler)
-
def log_create_with_lro(
self,
cls: Type["VertexAiResourceNoun"],
@@ -200,6 +191,10 @@ def __getattr__(self, attr: str):
_LOGGER = Logger(__name__)
+if not _LOGGER._logger.handlers:
+ handler = logging.StreamHandler(sys.stdout)
+ handler.setLevel(logging.INFO)
+ _LOGGER._logger.addHandler(handler)
class FutureManager(metaclass=abc.ABCMeta):
diff --git a/google/cloud/aiplatform/compat/__init__.py b/google/cloud/aiplatform/compat/__init__.py
index abf601b881..f4d7548025 100644
--- a/google/cloud/aiplatform/compat/__init__.py
+++ b/google/cloud/aiplatform/compat/__init__.py
@@ -144,6 +144,7 @@
services.model_service_client = services.model_service_client_v1
services.pipeline_service_client = services.pipeline_service_client_v1
services.prediction_service_client = services.prediction_service_client_v1
+ services.schedule_service_client = services.schedule_service_client_v1
services.specialist_pool_service_client = services.specialist_pool_service_client_v1
services.tensorboard_service_client = services.tensorboard_service_client_v1
services.index_service_client = services.index_service_client_v1
@@ -208,6 +209,8 @@
types.pipeline_state = types.pipeline_state_v1
types.prediction_service = types.prediction_service_v1
types.publisher_model = types.publisher_model_v1
+ types.schedule = types.schedule_v1
+ types.schedule_service = types.schedule_service_v1
types.specialist_pool = types.specialist_pool_v1
types.specialist_pool_service = types.specialist_pool_service_v1
types.study = types.study_v1
diff --git a/google/cloud/aiplatform/compat/services/__init__.py b/google/cloud/aiplatform/compat/services/__init__.py
index d1bd98b59a..d2e464425b 100644
--- a/google/cloud/aiplatform/compat/services/__init__.py
+++ b/google/cloud/aiplatform/compat/services/__init__.py
@@ -54,6 +54,9 @@
from google.cloud.aiplatform_v1beta1.services.pipeline_service import (
client as pipeline_service_client_v1beta1,
)
+from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import (
+ client as persistent_resource_service_client_v1beta1,
+)
from google.cloud.aiplatform_v1beta1.services.prediction_service import (
client as prediction_service_client_v1beta1,
)
@@ -106,6 +109,9 @@
from google.cloud.aiplatform_v1.services.prediction_service import (
client as prediction_service_client_v1,
)
+from google.cloud.aiplatform_v1.services.schedule_service import (
+ client as schedule_service_client_v1,
+)
from google.cloud.aiplatform_v1.services.specialist_pool_service import (
client as specialist_pool_service_client_v1,
)
@@ -130,6 +136,7 @@
model_service_client_v1,
pipeline_service_client_v1,
prediction_service_client_v1,
+ schedule_service_client_v1,
specialist_pool_service_client_v1,
tensorboard_service_client_v1,
vizier_service_client_v1,
@@ -145,6 +152,7 @@
match_service_client_v1beta1,
model_garden_service_client_v1beta1,
model_service_client_v1beta1,
+ persistent_resource_service_client_v1beta1,
pipeline_service_client_v1beta1,
prediction_service_client_v1beta1,
schedule_service_client_v1beta1,
diff --git a/google/cloud/aiplatform/compat/types/__init__.py b/google/cloud/aiplatform/compat/types/__init__.py
index 0b0bd9e08b..fb72dc7103 100644
--- a/google/cloud/aiplatform/compat/types/__init__.py
+++ b/google/cloud/aiplatform/compat/types/__init__.py
@@ -145,6 +145,8 @@
pipeline_state as pipeline_state_v1,
prediction_service as prediction_service_v1,
publisher_model as publisher_model_v1,
+ schedule as schedule_v1,
+ schedule_service as schedule_service_v1,
specialist_pool as specialist_pool_v1,
specialist_pool_service as specialist_pool_service_v1,
study as study_v1,
@@ -215,6 +217,8 @@
pipeline_state_v1,
prediction_service_v1,
publisher_model_v1,
+ schedule_v1,
+ schedule_service_v1,
specialist_pool_v1,
specialist_pool_service_v1,
tensorboard_v1,
diff --git a/google/cloud/aiplatform/constants/base.py b/google/cloud/aiplatform/constants/base.py
index 8145c847a0..e7be1523ce 100644
--- a/google/cloud/aiplatform/constants/base.py
+++ b/google/cloud/aiplatform/constants/base.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
+import os
from google.cloud.aiplatform import version as aiplatform_version
@@ -56,8 +56,8 @@
"us-west4",
}
)
-
-API_BASE_PATH = "aiplatform.googleapis.com"
+# This env variable injection is for testing, but not considered to be a public API.
+API_BASE_PATH = os.environ.get("_VERTEX_API_BASE_PATH") or "aiplatform.googleapis.com"
PREDICTION_API_BASE_PATH = API_BASE_PATH
# Batch Prediction
diff --git a/google/cloud/aiplatform/preview/constants/schedules.py b/google/cloud/aiplatform/constants/schedule.py
similarity index 97%
rename from google/cloud/aiplatform/preview/constants/schedules.py
rename to google/cloud/aiplatform/constants/schedule.py
index bdf9d45b65..ce4ab5a852 100644
--- a/google/cloud/aiplatform/preview/constants/schedules.py
+++ b/google/cloud/aiplatform/constants/schedule.py
@@ -16,7 +16,7 @@
#
from google.cloud.aiplatform.compat.types import (
- schedule_v1beta1 as gca_schedule,
+ schedule as gca_schedule,
)
from google.cloud.aiplatform.constants import pipeline as pipeline_constants
diff --git a/google/cloud/aiplatform/gapic_version.py b/google/cloud/aiplatform/gapic_version.py
index dcedf0af71..b26e773716 100644
--- a/google/cloud/aiplatform/gapic_version.py
+++ b/google/cloud/aiplatform/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.30.1" # {x-release-please-version}
+__version__ = "1.31.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py
index ef23f86275..90e7a8471f 100644
--- a/google/cloud/aiplatform/jobs.py
+++ b/google/cloud/aiplatform/jobs.py
@@ -1629,6 +1629,7 @@ def run(
tensorboard: Optional[str] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
+ disable_retries: bool = False,
) -> None:
"""Run this configured CustomJob.
@@ -1686,6 +1687,10 @@ def run(
will unblock and it will be executed in a concurrent Future.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
+ disable_retries (bool):
+ Indicates if the job should retry for internal errors after the
+ job starts running. If True, overrides
+ `restart_job_on_worker_restart` to False.
"""
network = network or initializer.global_config.network
@@ -1700,6 +1705,7 @@ def run(
tensorboard=tensorboard,
sync=sync,
create_request_timeout=create_request_timeout,
+ disable_retries=disable_retries,
)
@base.optional_sync()
@@ -1715,6 +1721,7 @@ def _run(
tensorboard: Optional[str] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
+ disable_retries: bool = False,
) -> None:
"""Helper method to ensure network synchronization and to run the configured CustomJob.
@@ -1770,6 +1777,10 @@ def _run(
will unblock and it will be executed in a concurrent Future.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
+ disable_retries (bool):
+ Indicates if the job should retry for internal errors after the
+ job starts running. If True, overrides
+ `restart_job_on_worker_restart` to False.
"""
self.submit(
service_account=service_account,
@@ -1781,6 +1792,7 @@ def _run(
experiment_run=experiment_run,
tensorboard=tensorboard,
create_request_timeout=create_request_timeout,
+ disable_retries=disable_retries,
)
self._block_until_complete()
@@ -1797,6 +1809,7 @@ def submit(
experiment_run: Optional[Union["aiplatform.ExperimentRun", str]] = None,
tensorboard: Optional[str] = None,
create_request_timeout: Optional[float] = None,
+ disable_retries: bool = False,
) -> None:
"""Submit the configured CustomJob.
@@ -1849,6 +1862,10 @@ def submit(
https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
+ disable_retries (bool):
+ Indicates if the job should retry for internal errors after the
+ job starts running. If True, overrides
+ `restart_job_on_worker_restart` to False.
Raises:
ValueError:
@@ -1869,11 +1886,12 @@ def submit(
if network:
self._gca_resource.job_spec.network = network
- if timeout or restart_job_on_worker_restart:
+ if timeout or restart_job_on_worker_restart or disable_retries:
timeout = duration_pb2.Duration(seconds=timeout) if timeout else None
self._gca_resource.job_spec.scheduling = gca_custom_job_compat.Scheduling(
timeout=timeout,
restart_job_on_worker_restart=restart_job_on_worker_restart,
+ disable_retries=disable_retries,
)
if enable_web_access:
@@ -2287,6 +2305,7 @@ def run(
tensorboard: Optional[str] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
+ disable_retries: bool = False,
) -> None:
"""Run this configured CustomJob.
@@ -2331,6 +2350,10 @@ def run(
will unblock and it will be executed in a concurrent Future.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
+ disable_retries (bool):
+ Indicates if the job should retry for internal errors after the
+ job starts running. If True, overrides
+ `restart_job_on_worker_restart` to False.
"""
network = network or initializer.global_config.network
@@ -2343,6 +2366,7 @@ def run(
tensorboard=tensorboard,
sync=sync,
create_request_timeout=create_request_timeout,
+ disable_retries=disable_retries,
)
@base.optional_sync()
@@ -2356,6 +2380,7 @@ def _run(
tensorboard: Optional[str] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
+ disable_retries: bool = False,
) -> None:
"""Helper method to ensure network synchronization and to run the configured CustomJob.
@@ -2398,6 +2423,10 @@ def _run(
will unblock and it will be executed in a concurrent Future.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
+ disable_retries (bool):
+ Indicates if the job should retry for internal errors after the
+ job starts running. If True, overrides
+ `restart_job_on_worker_restart` to False.
"""
if service_account:
self._gca_resource.trial_job_spec.service_account = service_account
@@ -2405,12 +2434,13 @@ def _run(
if network:
self._gca_resource.trial_job_spec.network = network
- if timeout or restart_job_on_worker_restart:
+ if timeout or restart_job_on_worker_restart or disable_retries:
duration = duration_pb2.Duration(seconds=timeout) if timeout else None
self._gca_resource.trial_job_spec.scheduling = (
gca_custom_job_compat.Scheduling(
timeout=duration,
restart_job_on_worker_restart=restart_job_on_worker_restart,
+ disable_retries=disable_retries,
)
)
diff --git a/google/cloud/aiplatform/pipeline_job_schedules.py b/google/cloud/aiplatform/pipeline_job_schedules.py
new file mode 100644
index 0000000000..1e66fd8809
--- /dev/null
+++ b/google/cloud/aiplatform/pipeline_job_schedules.py
@@ -0,0 +1,444 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List, Optional
+
+from google.auth import credentials as auth_credentials
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform import (
+ PipelineJob,
+)
+from google.cloud.aiplatform.schedules import (
+ _Schedule,
+)
+from google.cloud.aiplatform import utils
+from google.cloud.aiplatform.compat.types import (
+ schedule as gca_schedule,
+)
+from google.cloud.aiplatform.constants import (
+ schedule as schedule_constants,
+)
+
+from google.protobuf import field_mask_pb2 as field_mask
+
+_LOGGER = base.Logger(__name__)
+
+# Pattern for valid names used as a Vertex resource name.
+_VALID_NAME_PATTERN = schedule_constants._VALID_NAME_PATTERN
+
+# Pattern for an Artifact Registry URL.
+_VALID_AR_URL = schedule_constants._VALID_AR_URL
+
+# Pattern for any JSON or YAML file over HTTPS.
+_VALID_HTTPS_URL = schedule_constants._VALID_HTTPS_URL
+
+_SCHEDULE_ERROR_STATES = schedule_constants._SCHEDULE_ERROR_STATES
+
+
+class PipelineJobSchedule(
+ _Schedule,
+):
+ def __init__(
+ self,
+ pipeline_job: PipelineJob,
+ display_name: str,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ ):
+ """Retrieves a PipelineJobSchedule resource and instantiates its
+ representation.
+
+ Args:
+ pipeline_job (PipelineJob):
+ Required. PipelineJob used to init the schedule.
+ display_name (str):
+ Required. The user-defined name of this PipelineJobSchedule.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to create this PipelineJobSchedule.
+ Overrides credentials set in aiplatform.init.
+ project (str):
+ Optional. The project that you want to run this PipelineJobSchedule in.
+ If not set, the project set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to create PipelineJobSchedule. If not set,
+ location set in aiplatform.init will be used.
+ """
+ if not display_name:
+ display_name = self.__class__._generate_display_name()
+ utils.validate_display_name(display_name)
+
+ super().__init__(credentials=credentials, project=project, location=location)
+
+ self._parent = initializer.global_config.common_location_path(
+ project=project, location=location
+ )
+
+ create_pipeline_job_request = {
+ "parent": self._parent,
+ "pipeline_job": {
+ "runtime_config": pipeline_job.runtime_config,
+ "pipeline_spec": pipeline_job.pipeline_spec,
+ },
+ }
+ if "template_uri" in pipeline_job._gca_resource:
+ create_pipeline_job_request["pipeline_job"][
+ "template_uri"
+ ] = pipeline_job._gca_resource.template_uri
+ pipeline_job_schedule_args = {
+ "display_name": display_name,
+ "create_pipeline_job_request": create_pipeline_job_request,
+ }
+
+ self._gca_resource = gca_schedule.Schedule(**pipeline_job_schedule_args)
+
+ def create(
+ self,
+ cron: str,
+ start_time: Optional[str] = None,
+ end_time: Optional[str] = None,
+ allow_queueing: bool = False,
+ max_run_count: Optional[int] = None,
+ max_concurrent_run_count: int = 1,
+ service_account: Optional[str] = None,
+ network: Optional[str] = None,
+ create_request_timeout: Optional[float] = None,
+ ) -> None:
+ """Create a PipelineJobSchedule.
+
+ Args:
+ cron (str):
+ Required. Time specification (cron schedule expression) to launch scheduled runs.
+ To explicitly set a timezone to the cron tab, apply a prefix: "CRON_TZ=${IANA_TIME_ZONE}" or "TZ=${IANA_TIME_ZONE}".
+ The ${IANA_TIME_ZONE} may only be a valid string from IANA time zone database.
+ For example, "CRON_TZ=America/New_York 1 * * * *", or "TZ=America/New_York 1 * * * *".
+ start_time (str):
+ Optional. Timestamp after which the first run can be scheduled.
+ If unspecified, it defaults to the schedule creation timestamp.
+ end_time (str):
+ Optional. Timestamp after which no more runs will be scheduled.
+ If unspecified, then runs will be scheduled indefinitely.
+ allow_queueing (bool):
+ Optional. Whether new scheduled runs can be queued when max_concurrent_runs limit is reached.
+ max_run_count (int):
+ Optional. Maximum run count of the schedule.
+ If specified, The schedule will be completed when either started_run_count >= max_run_count or when end_time is reached.
+ Must be positive and <= 2^63-1.
+ max_concurrent_run_count (int):
+ Optional. Maximum number of runs that can be started concurrently for this PipelineJobSchedule.
+ service_account (str):
+ Optional. Specifies the service account for workload run-as account.
+ Users submitting jobs must have act-as permission on this run-as account.
+ network (str):
+ Optional. The full name of the Compute Engine network to which the job
+ should be peered. For example, projects/12345/global/networks/myVPC.
+ Private services access must already be configured for the network.
+ If left unspecified, the network set in aiplatform.init will be used.
+ Otherwise, the job is not peered with any network.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+ """
+ network = network or initializer.global_config.network
+
+ self._create(
+ cron=cron,
+ start_time=start_time,
+ end_time=end_time,
+ allow_queueing=allow_queueing,
+ max_run_count=max_run_count,
+ max_concurrent_run_count=max_concurrent_run_count,
+ service_account=service_account,
+ network=network,
+ create_request_timeout=create_request_timeout,
+ )
+
+ def _create(
+ self,
+ cron: str,
+ start_time: Optional[str] = None,
+ end_time: Optional[str] = None,
+ allow_queueing: bool = False,
+ max_run_count: Optional[int] = None,
+ max_concurrent_run_count: int = 1,
+ service_account: Optional[str] = None,
+ network: Optional[str] = None,
+ create_request_timeout: Optional[float] = None,
+ ) -> None:
+ """Helper method to create the PipelineJobSchedule.
+
+ Args:
+ cron (str):
+ Required. Time specification (cron schedule expression) to launch scheduled runs.
+ To explicitly set a timezone to the cron tab, apply a prefix: "CRON_TZ=${IANA_TIME_ZONE}" or "TZ=${IANA_TIME_ZONE}".
+ The ${IANA_TIME_ZONE} may only be a valid string from IANA time zone database.
+ For example, "CRON_TZ=America/New_York 1 * * * *", or "TZ=America/New_York 1 * * * *".
+ start_time (str):
+ Optional. Timestamp after which the first run can be scheduled.
+ If unspecified, it defaults to the schedule creation timestamp.
+ end_time (str):
+ Optional. Timestamp after which no more runs will be scheduled.
+ If unspecified, then runs will be scheduled indefinitely.
+ allow_queueing (bool):
+ Optional. Whether new scheduled runs can be queued when max_concurrent_runs limit is reached.
+ max_run_count (int):
+ Optional. Maximum run count of the schedule.
+ If specified, The schedule will be completed when either started_run_count >= max_run_count or when end_time is reached.
+ Must be positive and <= 2^63-1.
+ max_concurrent_run_count (int):
+ Optional. Maximum number of runs that can be started concurrently for this PipelineJobSchedule.
+ service_account (str):
+ Optional. Specifies the service account for workload run-as account.
+ Users submitting jobs must have act-as permission on this run-as account.
+ network (str):
+ Optional. The full name of the Compute Engine network to which the job
+ should be peered. For example, projects/12345/global/networks/myVPC.
+ Private services access must already be configured for the network.
+ If left unspecified, the network set in aiplatform.init will be used.
+ Otherwise, the job is not peered with any network.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+ """
+ if cron:
+ self._gca_resource.cron = cron
+ if start_time:
+ self._gca_resource.start_time = start_time
+ if end_time:
+ self._gca_resource.end_time = end_time
+ if allow_queueing:
+ self._gca_resource.allow_queueing = allow_queueing
+ if max_run_count:
+ self._gca_resource.max_run_count = max_run_count
+ if max_concurrent_run_count:
+ self._gca_resource.max_concurrent_run_count = max_concurrent_run_count
+
+ network = network or initializer.global_config.network
+
+ if service_account:
+ self._gca_resource.create_pipeline_job_request.pipeline_job.service_account = (
+ service_account
+ )
+
+ if network:
+ self._gca_resource.create_pipeline_job_request.pipeline_job.network = (
+ network
+ )
+
+ _LOGGER.log_create_with_lro(self.__class__)
+
+ self._gca_resource = self.api_client.create_schedule(
+ parent=self._parent,
+ schedule=self._gca_resource,
+ timeout=create_request_timeout,
+ )
+
+ _LOGGER.log_create_complete_with_getter(
+ self.__class__, self._gca_resource, "schedule"
+ )
+
+ _LOGGER.info("View Schedule:\n%s" % self._dashboard_uri())
+
+ @classmethod
+ def list(
+ cls,
+ filter: Optional[str] = None,
+ order_by: Optional[str] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> List["PipelineJobSchedule"]:
+ """List all instances of this PipelineJobSchedule resource.
+
+ Example Usage:
+
+ aiplatform.PipelineJobSchedule.list(
+ filter='display_name="experiment_a27"',
+ order_by='create_time desc'
+ )
+
+ Args:
+ filter (str):
+ Optional. An expression for filtering the results of the request.
+ For field names both snake_case and camelCase are supported.
+ order_by (str):
+ Optional. A comma-separated list of fields to order by, sorted in
+ ascending order. Use "desc" after a field name for descending.
+ Supported fields: `display_name`, `create_time`, `update_time`
+ project (str):
+ Optional. Project to retrieve list from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve list from. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to retrieve list. Overrides
+ credentials set in aiplatform.init.
+
+ Returns:
+ List[PipelineJobSchedule] - A list of PipelineJobSchedule resource objects.
+ """
+ return cls._list_with_local_order(
+ filter=filter,
+ order_by=order_by,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
+ def list_jobs(
+ self,
+ filter: Optional[str] = None,
+ order_by: Optional[str] = None,
+ enable_simple_view: bool = True,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> List[PipelineJob]:
+ """List all PipelineJob 's created by this PipelineJobSchedule.
+
+ Example usage:
+
+ pipeline_job_schedule.list_jobs(order_by='create_time_desc')
+
+ Args:
+ filter (str):
+ Optional. An expression for filtering the results of the request.
+ For field names both snake_case and camelCase are supported.
+ order_by (str):
+ Optional. A comma-separated list of fields to order by, sorted in
+ ascending order. Use "desc" after a field name for descending.
+ Supported fields: `display_name`, `create_time`, `update_time`
+ enable_simple_view (bool):
+ Optional. Whether to pass the `read_mask` parameter to the list call.
+ Defaults to True if not provided. This will improve the performance of calling
+ list(). However, the returned PipelineJob list will not include all fields for
+ each PipelineJob. Setting this to True will exclude the following fields in your
+ response: `runtime_config`, `service_account`, `network`, and some subfields of
+ `pipeline_spec` and `job_detail`. The following fields will be included in
+ each PipelineJob resource in your response: `state`, `display_name`,
+ `pipeline_spec.pipeline_info`, `create_time`, `start_time`, `end_time`,
+ `update_time`, `labels`, `template_uri`, `template_metadata.version`,
+ `job_detail.pipeline_run_context`, `job_detail.pipeline_context`.
+ project (str):
+ Optional. Project to retrieve list from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve list from. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to retrieve list. Overrides
+ credentials set in aiplatform.init.
+
+ Returns:
+ List[PipelineJob] - A list of PipelineJob resource objects.
+ """
+ list_filter = f"schedule_name={self._gca_resource.name}"
+ if filter:
+ list_filter = list_filter + f" AND {filter}"
+
+ return PipelineJob.list(
+ filter=list_filter,
+ order_by=order_by,
+ enable_simple_view=enable_simple_view,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
+ def update(
+ self,
+ display_name: Optional[str] = None,
+ cron: Optional[str] = None,
+ start_time: Optional[str] = None,
+ end_time: Optional[str] = None,
+ allow_queueing: Optional[bool] = None,
+ max_run_count: Optional[int] = None,
+ max_concurrent_run_count: Optional[int] = None,
+ ) -> None:
+ """Update an existing PipelineJobSchedule.
+
+ Example usage:
+
+ pipeline_job_schedule.update(
+ display_name='updated-display-name',
+ cron='* * * * *',
+ )
+
+ Args:
+ display_name (str):
+ Optional. The user-defined name of this PipelineJobSchedule.
+ cron (str):
+ Optional. Time specification (cron schedule expression) to launch scheduled runs.
+ To explicitly set a timezone to the cron tab, apply a prefix: "CRON_TZ=${IANA_TIME_ZONE}" or "TZ=${IANA_TIME_ZONE}".
+ The ${IANA_TIME_ZONE} may only be a valid string from IANA time zone database.
+ For example, "CRON_TZ=America/New_York 1 * * * *", or "TZ=America/New_York 1 * * * *".
+ start_time (str):
+ Optional. Timestamp after which the first run can be scheduled.
+ If unspecified, it defaults to the schedule creation timestamp.
+ end_time (str):
+ Optional. Timestamp after which no more runs will be scheduled.
+ If unspecified, then runs will be scheduled indefinitely.
+ allow_queueing (bool):
+ Optional. Whether new scheduled runs can be queued when max_concurrent_runs limit is reached.
+ max_run_count (int):
+ Optional. Maximum run count of the schedule.
+ If specified, The schedule will be completed when either started_run_count >= max_run_count or when end_time is reached.
+ Must be positive and <= 2^63-1.
+ max_concurrent_run_count (int):
+ Optional. Maximum number of runs that can be started concurrently for this PipelineJobSchedule.
+
+ Raises:
+ RuntimeError: User tried to call update() before create().
+ """
+ pipeline_job_schedule = self._gca_resource
+ if pipeline_job_schedule.state in _SCHEDULE_ERROR_STATES:
+ raise RuntimeError(
+ "Not updating PipelineJobSchedule: PipelineJobSchedule must be active or completed."
+ )
+
+ updated_fields = []
+ if display_name is not None:
+ updated_fields.append("display_name")
+ setattr(pipeline_job_schedule, "display_name", display_name)
+ if cron is not None:
+ updated_fields.append("cron")
+ setattr(pipeline_job_schedule, "cron", cron)
+ if start_time is not None:
+ updated_fields.append("start_time")
+ setattr(pipeline_job_schedule, "start_time", start_time)
+ if end_time is not None:
+ updated_fields.append("end_time")
+ setattr(pipeline_job_schedule, "end_time", end_time)
+ if allow_queueing is not None:
+ updated_fields.append("allow_queueing")
+ setattr(pipeline_job_schedule, "allow_queueing", allow_queueing)
+ if max_run_count is not None:
+ updated_fields.append("max_run_count")
+ setattr(pipeline_job_schedule, "max_run_count", max_run_count)
+ if max_concurrent_run_count is not None:
+ updated_fields.append("max_concurrent_run_count")
+ setattr(
+ pipeline_job_schedule,
+ "max_concurrent_run_count",
+ max_concurrent_run_count,
+ )
+
+ update_mask = field_mask.FieldMask(paths=updated_fields)
+ self.api_client.update_schedule(
+ schedule=pipeline_job_schedule,
+ update_mask=update_mask,
+ )
diff --git a/google/cloud/aiplatform/pipeline_jobs.py b/google/cloud/aiplatform/pipeline_jobs.py
index 3efa6c1e84..47c33f79b5 100644
--- a/google/cloud/aiplatform/pipeline_jobs.py
+++ b/google/cloud/aiplatform/pipeline_jobs.py
@@ -431,6 +431,92 @@ def submit(
if experiment:
self._associate_to_experiment(experiment)
+ def create_schedule(
+ self,
+ cron: str,
+ display_name: str,
+ start_time: Optional[str] = None,
+ end_time: Optional[str] = None,
+ allow_queueing: bool = False,
+ max_run_count: Optional[int] = None,
+ max_concurrent_run_count: int = 1,
+ service_account: Optional[str] = None,
+ network: Optional[str] = None,
+ create_request_timeout: Optional[float] = None,
+ ) -> "aiplatform.PipelineJobSchedule":
+ """Creates a PipelineJobSchedule directly from a PipelineJob.
+
+ Example Usage:
+
+ pipeline_job = aiplatform.PipelineJob(
+ display_name='job_display_name',
+ template_path='your_pipeline.yaml',
+ )
+ pipeline_job.run()
+ pipeline_job_schedule = pipeline_job.create_schedule(
+ cron='* * * * *',
+ display_name='schedule_display_name',
+ )
+
+ Args:
+ cron (str):
+ Required. Time specification (cron schedule expression) to launch scheduled runs.
+ To explicitly set a timezone to the cron tab, apply a prefix: "CRON_TZ=${IANA_TIME_ZONE}" or "TZ=${IANA_TIME_ZONE}".
+ The ${IANA_TIME_ZONE} may only be a valid string from IANA time zone database.
+ For example, "CRON_TZ=America/New_York 1 * * * *", or "TZ=America/New_York 1 * * * *".
+ display_name (str):
+ Required. The user-defined name of this PipelineJobSchedule.
+ start_time (str):
+ Optional. Timestamp after which the first run can be scheduled.
+ If unspecified, it defaults to the schedule creation timestamp.
+ end_time (str):
+ Optional. Timestamp after which no more runs will be scheduled.
+ If unspecified, then runs will be scheduled indefinitely.
+ allow_queueing (bool):
+ Optional. Whether new scheduled runs can be queued when max_concurrent_runs limit is reached.
+ max_run_count (int):
+ Optional. Maximum run count of the schedule.
+ If specified, The schedule will be completed when either started_run_count >= max_run_count or when end_time is reached.
+ Must be positive and <= 2^63-1.
+ max_concurrent_run_count (int):
+ Optional. Maximum number of runs that can be started concurrently for this PipelineJobSchedule.
+ service_account (str):
+ Optional. Specifies the service account for workload run-as account.
+ Users submitting jobs must have act-as permission on this run-as account.
+ network (str):
+ Optional. The full name of the Compute Engine network to which the job
+ should be peered. For example, projects/12345/global/networks/myVPC.
+ Private services access must already be configured for the network.
+ If left unspecified, the network set in aiplatform.init will be used.
+ Otherwise, the job is not peered with any network.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+
+ Returns:
+ A Vertex AI PipelineJobSchedule.
+ """
+ if not display_name:
+ display_name = self._generate_display_name(prefix="PipelineJobSchedule")
+ utils.validate_display_name(display_name)
+
+ pipeline_job_schedule = aiplatform.PipelineJobSchedule(
+ pipeline_job=self,
+ display_name=display_name,
+ )
+
+ pipeline_job_schedule.create(
+ cron=cron,
+ start_time=start_time,
+ end_time=end_time,
+ allow_queueing=allow_queueing,
+ max_run_count=max_run_count,
+ max_concurrent_run_count=max_concurrent_run_count,
+ service_account=service_account,
+ network=network,
+ create_request_timeout=create_request_timeout,
+ )
+ return pipeline_job_schedule
+
def wait(self):
"""Wait for this PipelineJob to complete."""
if self._latest_future is None:
diff --git a/google/cloud/aiplatform/preview/jobs.py b/google/cloud/aiplatform/preview/jobs.py
index 35e611f802..7ba408db95 100644
--- a/google/cloud/aiplatform/preview/jobs.py
+++ b/google/cloud/aiplatform/preview/jobs.py
@@ -238,6 +238,7 @@ def submit(
experiment_run: Optional[Union["aiplatform.ExperimentRun", str]] = None,
tensorboard: Optional[str] = None,
create_request_timeout: Optional[float] = None,
+ disable_retries: bool = False,
) -> None:
"""Submit the configured CustomJob.
@@ -290,6 +291,10 @@ def submit(
https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
+ disable_retries (bool):
+ Indicates if the job should retry for internal errors after the
+ job starts running. If True, overrides
+ `restart_job_on_worker_restart` to False.
Raises:
ValueError:
@@ -310,11 +315,12 @@ def submit(
if network:
self._gca_resource.job_spec.network = network
- if timeout or restart_job_on_worker_restart:
+ if timeout or restart_job_on_worker_restart or disable_retries:
timeout = duration_pb2.Duration(seconds=timeout) if timeout else None
self._gca_resource.job_spec.scheduling = gca_custom_job_compat.Scheduling(
timeout=timeout,
restart_job_on_worker_restart=restart_job_on_worker_restart,
+ disable_retries=disable_retries,
)
if enable_web_access:
diff --git a/google/cloud/aiplatform/preview/pipelinejob/pipeline_jobs.py b/google/cloud/aiplatform/preview/pipelinejob/pipeline_jobs.py
index 8e1a58da30..2830d308ce 100644
--- a/google/cloud/aiplatform/preview/pipelinejob/pipeline_jobs.py
+++ b/google/cloud/aiplatform/preview/pipelinejob/pipeline_jobs.py
@@ -17,33 +17,17 @@
from typing import Optional
-from google.cloud.aiplatform import base
-from google.cloud.aiplatform import pipeline_jobs
-from google.cloud.aiplatform import utils
-from google.cloud.aiplatform.constants import pipeline as pipeline_constants
+from google.cloud.aiplatform.pipeline_jobs import (
+ PipelineJob as PipelineJobGa,
+)
+from google.cloud.aiplatform import pipeline_job_schedules
+
from google.cloud.aiplatform.metadata import constants as metadata_constants
from google.cloud.aiplatform.metadata import experiment_resources
-_LOGGER = base.Logger(__name__)
-
-_PIPELINE_COMPLETE_STATES = pipeline_constants._PIPELINE_COMPLETE_STATES
-
-_PIPELINE_ERROR_STATES = pipeline_constants._PIPELINE_ERROR_STATES
-
-# Pattern for valid names used as a Vertex resource name.
-_VALID_NAME_PATTERN = pipeline_constants._VALID_NAME_PATTERN
-
-# Pattern for an Artifact Registry URL.
-_VALID_AR_URL = pipeline_constants._VALID_AR_URL
-
-# Pattern for any JSON or YAML file over HTTPS.
-_VALID_HTTPS_URL = pipeline_constants._VALID_HTTPS_URL
-
-_READ_MASK_FIELDS = pipeline_constants._READ_MASK_FIELDS
-
class _PipelineJob(
- pipeline_jobs.PipelineJob,
+ PipelineJobGa,
experiment_loggable_schemas=(
experiment_resources._ExperimentLoggableSchema(
title=metadata_constants.SYSTEM_PIPELINE_RUN
@@ -116,21 +100,9 @@ def create_schedule(
Returns:
A Vertex AI PipelineJobSchedule.
"""
- from google.cloud.aiplatform.preview.pipelinejobschedule import (
- pipeline_job_schedules,
- )
-
- if not display_name:
- display_name = self._generate_display_name(prefix="PipelineJobSchedule")
- utils.validate_display_name(display_name)
-
- pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
- pipeline_job=self,
+ return super().create_schedule(
+ cron=cron_expression,
display_name=display_name,
- )
-
- pipeline_job_schedule.create(
- cron_expression=cron_expression,
start_time=start_time,
end_time=end_time,
allow_queueing=allow_queueing,
@@ -140,4 +112,3 @@ def create_schedule(
network=network,
create_request_timeout=create_request_timeout,
)
- return pipeline_job_schedule
diff --git a/google/cloud/aiplatform/preview/pipelinejobschedule/pipeline_job_schedules.py b/google/cloud/aiplatform/preview/pipelinejobschedule/pipeline_job_schedules.py
index 16a546b1f4..2c119c4922 100644
--- a/google/cloud/aiplatform/preview/pipelinejobschedule/pipeline_job_schedules.py
+++ b/google/cloud/aiplatform/preview/pipelinejobschedule/pipeline_job_schedules.py
@@ -18,45 +18,20 @@
from typing import List, Optional
from google.auth import credentials as auth_credentials
-from google.cloud.aiplatform import base
-from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import (
PipelineJob,
)
-from google.cloud.aiplatform import utils
-from google.cloud.aiplatform.compat.types import (
- schedule_v1beta1 as gca_schedule,
+from google.cloud.aiplatform.pipeline_job_schedules import (
+ PipelineJobSchedule as PipelineJobScheduleGa,
)
-from google.cloud.aiplatform.preview.constants import (
- schedules as schedule_constants,
+from google.cloud.aiplatform.preview.schedule.schedules import (
+ _Schedule as _SchedulePreview,
)
-from google.cloud.aiplatform.preview.schedule.schedules import _Schedule
-
-# TODO(b/283318141): Remove imports once PipelineJobSchedule is GA.
-from google.cloud.aiplatform_v1.types import (
- pipeline_job as gca_pipeline_job_v1,
-)
-from google.cloud.aiplatform_v1beta1.types import (
- pipeline_job as gca_pipeline_job_v1beta1,
-)
-from google.protobuf import field_mask_pb2 as field_mask
-
-_LOGGER = base.Logger(__name__)
-
-# Pattern for valid names used as a Vertex resource name.
-_VALID_NAME_PATTERN = schedule_constants._VALID_NAME_PATTERN
-
-# Pattern for an Artifact Registry URL.
-_VALID_AR_URL = schedule_constants._VALID_AR_URL
-
-# Pattern for any JSON or YAML file over HTTPS.
-_VALID_HTTPS_URL = schedule_constants._VALID_HTTPS_URL
-
-_SCHEDULE_ERROR_STATES = schedule_constants._SCHEDULE_ERROR_STATES
class PipelineJobSchedule(
- _Schedule,
+ PipelineJobScheduleGa,
+ _SchedulePreview,
):
def __init__(
self,
@@ -84,39 +59,13 @@ def __init__(
Optional. Location to create PipelineJobSchedule. If not set,
location set in aiplatform.init will be used.
"""
- if not display_name:
- display_name = self.__class__._generate_display_name()
- utils.validate_display_name(display_name)
-
- super().__init__(credentials=credentials, project=project, location=location)
-
- self._parent = initializer.global_config.common_location_path(
- project=project, location=location
- )
-
- # TODO(b/283318141): Remove temporary logic once PipelineJobSchedule is GA.
- runtime_config = gca_pipeline_job_v1beta1.PipelineJob.RuntimeConfig.deserialize(
- gca_pipeline_job_v1.PipelineJob.RuntimeConfig.serialize(
- pipeline_job.runtime_config
- )
+ super().__init__(
+ pipeline_job=pipeline_job,
+ display_name=display_name,
+ credentials=credentials,
+ project=project,
+ location=location,
)
- create_pipeline_job_request = {
- "parent": self._parent,
- "pipeline_job": {
- "runtime_config": runtime_config,
- "pipeline_spec": pipeline_job.pipeline_spec,
- },
- }
- if "template_uri" in pipeline_job._gca_resource:
- create_pipeline_job_request["pipeline_job"][
- "template_uri"
- ] = pipeline_job._gca_resource.template_uri
- pipeline_job_schedule_args = {
- "display_name": display_name,
- "create_pipeline_job_request": create_pipeline_job_request,
- }
-
- self._gca_resource = gca_schedule.Schedule(**pipeline_job_schedule_args)
def create(
self,
@@ -164,10 +113,8 @@ def create(
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
"""
- network = network or initializer.global_config.network
-
- self._create(
- cron_expression=cron_expression,
+ super().create(
+ cron=cron_expression,
start_time=start_time,
end_time=end_time,
allow_queueing=allow_queueing,
@@ -178,138 +125,6 @@ def create(
create_request_timeout=create_request_timeout,
)
- def _create(
- self,
- cron_expression: str,
- start_time: Optional[str] = None,
- end_time: Optional[str] = None,
- allow_queueing: bool = False,
- max_run_count: Optional[int] = None,
- max_concurrent_run_count: int = 1,
- service_account: Optional[str] = None,
- network: Optional[str] = None,
- create_request_timeout: Optional[float] = None,
- ) -> None:
- """Helper method to create the PipelineJobSchedule.
-
- Args:
- cron_expression (str):
- Required. Time specification (cron schedule expression) to launch scheduled runs.
- To explicitly set a timezone to the cron tab, apply a prefix: "CRON_TZ=${IANA_TIME_ZONE}" or "TZ=${IANA_TIME_ZONE}".
- The ${IANA_TIME_ZONE} may only be a valid string from IANA time zone database.
- For example, "CRON_TZ=America/New_York 1 * * * *", or "TZ=America/New_York 1 * * * *".
- start_time (str):
- Optional. Timestamp after which the first run can be scheduled.
- If unspecified, it defaults to the schedule creation timestamp.
- end_time (str):
- Optional. Timestamp after which no more runs will be scheduled.
- If unspecified, then runs will be scheduled indefinitely.
- allow_queueing (bool):
- Optional. Whether new scheduled runs can be queued when max_concurrent_runs limit is reached.
- max_run_count (int):
- Optional. Maximum run count of the schedule.
- If specified, The schedule will be completed when either started_run_count >= max_run_count or when end_time is reached.
- Must be positive and <= 2^63-1.
- max_concurrent_run_count (int):
- Optional. Maximum number of runs that can be started concurrently for this PipelineJobSchedule.
- service_account (str):
- Optional. Specifies the service account for workload run-as account.
- Users submitting jobs must have act-as permission on this run-as account.
- network (str):
- Optional. The full name of the Compute Engine network to which the job
- should be peered. For example, projects/12345/global/networks/myVPC.
- Private services access must already be configured for the network.
- If left unspecified, the network set in aiplatform.init will be used.
- Otherwise, the job is not peered with any network.
- create_request_timeout (float):
- Optional. The timeout for the create request in seconds.
- """
- if cron_expression:
- self._gca_resource.cron = cron_expression
- if start_time:
- self._gca_resource.start_time = start_time
- if end_time:
- self._gca_resource.end_time = end_time
- if allow_queueing:
- self._gca_resource.allow_queueing = allow_queueing
- if max_run_count:
- self._gca_resource.max_run_count = max_run_count
- if max_concurrent_run_count:
- self._gca_resource.max_concurrent_run_count = max_concurrent_run_count
-
- network = network or initializer.global_config.network
-
- if service_account:
- self._gca_resource.create_pipeline_job_request.pipeline_job.service_account = (
- service_account
- )
-
- if network:
- self._gca_resource.create_pipeline_job_request.pipeline_job.network = (
- network
- )
-
- _LOGGER.log_create_with_lro(self.__class__)
-
- self._gca_resource = self.api_client.create_schedule(
- parent=self._parent,
- schedule=self._gca_resource,
- timeout=create_request_timeout,
- )
-
- _LOGGER.log_create_complete_with_getter(
- self.__class__, self._gca_resource, "schedule"
- )
-
- _LOGGER.info("View Schedule:\n%s" % self._dashboard_uri())
-
- @classmethod
- def list(
- cls,
- filter: Optional[str] = None,
- order_by: Optional[str] = None,
- project: Optional[str] = None,
- location: Optional[str] = None,
- credentials: Optional[auth_credentials.Credentials] = None,
- ) -> List["PipelineJobSchedule"]:
- """List all instances of this PipelineJobSchedule resource.
-
- Example Usage:
-
- aiplatform.PipelineJobSchedule.list(
- filter='display_name="experiment_a27"',
- order_by='create_time desc'
- )
-
- Args:
- filter (str):
- Optional. An expression for filtering the results of the request.
- For field names both snake_case and camelCase are supported.
- order_by (str):
- Optional. A comma-separated list of fields to order by, sorted in
- ascending order. Use "desc" after a field name for descending.
- Supported fields: `display_name`, `create_time`, `update_time`
- project (str):
- Optional. Project to retrieve list from. If not set, project
- set in aiplatform.init will be used.
- location (str):
- Optional. Location to retrieve list from. If not set, location
- set in aiplatform.init will be used.
- credentials (auth_credentials.Credentials):
- Optional. Custom credentials to use to retrieve list. Overrides
- credentials set in aiplatform.init.
-
- Returns:
- List[PipelineJobSchedule] - A list of PipelineJobSchedule resource objects.
- """
- return cls._list_with_local_order(
- filter=filter,
- order_by=order_by,
- project=project,
- location=location,
- credentials=credentials,
- )
-
def list_jobs(
self,
filter: Optional[str] = None,
@@ -357,12 +172,8 @@ def list_jobs(
Returns:
List[PipelineJob] - A list of PipelineJob resource objects.
"""
- list_filter = f"schedule_name={self._gca_resource.name}"
- if filter:
- list_filter = list_filter + f" AND {filter}"
-
- return PipelineJob.list(
- filter=list_filter,
+ return super().list_jobs(
+ filter=filter,
order_by=order_by,
enable_simple_view=enable_simple_view,
project=project,
@@ -386,7 +197,7 @@ def update(
pipeline_job_schedule.update(
display_name='updated-display-name',
- cron_expression='1 2 3 4 5',
+ cron_expression='* * * * *',
)
Args:
@@ -415,41 +226,12 @@ def update(
Raises:
RuntimeError: User tried to call update() before create().
"""
- pipeline_job_schedule = self._gca_resource
- if pipeline_job_schedule.state in _SCHEDULE_ERROR_STATES:
- raise RuntimeError(
- "Not updating PipelineJobSchedule: PipelineJobSchedule must be active or completed."
- )
-
- updated_fields = []
- if display_name is not None:
- updated_fields.append("display_name")
- setattr(pipeline_job_schedule, "display_name", display_name)
- if cron_expression is not None:
- updated_fields.append("cron")
- setattr(pipeline_job_schedule, "cron", cron_expression)
- if start_time is not None:
- updated_fields.append("start_time")
- setattr(pipeline_job_schedule, "start_time", start_time)
- if end_time is not None:
- updated_fields.append("end_time")
- setattr(pipeline_job_schedule, "end_time", end_time)
- if allow_queueing is not None:
- updated_fields.append("allow_queueing")
- setattr(pipeline_job_schedule, "allow_queueing", allow_queueing)
- if max_run_count is not None:
- updated_fields.append("max_run_count")
- setattr(pipeline_job_schedule, "max_run_count", max_run_count)
- if max_concurrent_run_count is not None:
- updated_fields.append("max_concurrent_run_count")
- setattr(
- pipeline_job_schedule,
- "max_concurrent_run_count",
- max_concurrent_run_count,
- )
-
- update_mask = field_mask.FieldMask(paths=updated_fields)
- self.api_client.update_schedule(
- schedule=pipeline_job_schedule,
- update_mask=update_mask,
+ super().update(
+ display_name=display_name,
+ cron=cron_expression,
+ start_time=start_time,
+ end_time=end_time,
+ allow_queueing=allow_queueing,
+ max_run_count=max_run_count,
+ max_concurrent_run_count=max_concurrent_run_count,
)
diff --git a/google/cloud/aiplatform/preview/schedule/schedules.py b/google/cloud/aiplatform/preview/schedule/schedules.py
index 8297db3f90..5c170a812b 100644
--- a/google/cloud/aiplatform/preview/schedule/schedules.py
+++ b/google/cloud/aiplatform/preview/schedule/schedules.py
@@ -15,44 +15,16 @@
# limitations under the License.
#
-import time
-from typing import Any, Optional
-
from google.auth import credentials as auth_credentials
-from google.cloud.aiplatform import base
-from google.cloud.aiplatform import utils
-from google.cloud.aiplatform.compat.types import (
- schedule_v1beta1 as gca_schedule,
-)
-from google.cloud.aiplatform.preview.constants import (
- schedules as schedule_constants,
-)
-
-_LOGGER = base.Logger(__name__)
-_SCHEDULE_COMPLETE_STATES = schedule_constants._SCHEDULE_COMPLETE_STATES
-
-_SCHEDULE_ERROR_STATES = schedule_constants._SCHEDULE_ERROR_STATES
+from google.cloud.aiplatform.schedules import _Schedule as _ScheduleGa
class _Schedule(
- base.VertexAiStatefulResource,
+ _ScheduleGa,
):
"""Preview Schedule resource for Vertex AI."""
- client_class = utils.ScheduleClientWithOverride
- _resource_noun = "schedules"
- _delete_method = "delete_schedule"
- _getter_method = "get_schedule"
- _list_method = "list_schedules"
- _pause_method = "pause_schedule"
- _resume_method = "resume_schedule"
- _parse_resource_name_method = "parse_schedule_path"
- _format_resource_name_method = "schedule_path"
-
- # Required by the done() method
- _valid_done_states = schedule_constants._SCHEDULE_COMPLETE_STATES
-
def __init__(
self,
credentials: auth_credentials.Credentials,
@@ -73,99 +45,6 @@ def __init__(
"""
super().__init__(project=project, location=location, credentials=credentials)
- @classmethod
- def get(
- cls,
- schedule_id: str,
- project: Optional[str] = None,
- location: Optional[str] = None,
- credentials: Optional[auth_credentials.Credentials] = None,
- ) -> Any:
- """Get a Vertex AI Schedule for the given resource_name.
-
- Args:
- schedule_id (str):
- Required. Schedule ID used to identify or locate the schedule.
- project (str):
- Optional. Project to retrieve dataset from. If not set, project
- set in aiplatform.init will be used.
- location (str):
- Optional. Location to retrieve dataset from. If not set,
- location set in aiplatform.init will be used.
- credentials (auth_credentials.Credentials):
- Optional. Custom credentials to use to upload this model.
- Overrides credentials set in aiplatform.init.
-
- Returns:
- A Vertex AI Schedule.
- """
- self = cls._empty_constructor(
- project=project,
- location=location,
- credentials=credentials,
- resource_name=schedule_id,
- )
-
- self._gca_resource = self._get_gca_resource(resource_name=schedule_id)
-
- return self
-
- def pause(self) -> None:
- """Starts asynchronous pause on the Schedule.
-
- Changes Schedule state from State.ACTIVE to State.PAUSED.
- """
- self.api_client.pause_schedule(name=self.resource_name)
-
- def resume(
- self,
- catch_up: bool = True,
- ) -> None:
- """Starts asynchronous resume on the Schedule.
-
- Changes Schedule state from State.PAUSED to State.ACTIVE.
-
- Args:
- catch_up (bool):
- Optional. Whether to backfill missed runs when the Schedule is
- resumed from State.PAUSED.
- """
- self.api_client.resume_schedule(name=self.resource_name)
-
- def done(self) -> bool:
- """Helper method that return True is Schedule is done. False otherwise."""
- if not self._gca_resource:
- return False
-
- return self.state in _SCHEDULE_COMPLETE_STATES
-
- def wait(self) -> None:
- """Wait for all runs scheduled by this Schedule to complete."""
- if self._latest_future is None:
- self._block_until_complete()
- else:
- super().wait()
-
- @property
- def state(self) -> Optional[gca_schedule.Schedule.State]:
- """Current Schedule state.
-
- Returns:
- Schedule state.
- """
- self._sync_gca_resource()
- return self._gca_resource.state
-
- @property
- def max_run_count(self) -> int:
- """Current Schedule max_run_count.
-
- Returns:
- Schedule max_run_count.
- """
- self._sync_gca_resource()
- return self._gca_resource.max_run_count
-
@property
def cron_expression(self) -> str:
"""Current Schedule cron expression.
@@ -173,66 +52,4 @@ def cron_expression(self) -> str:
Returns:
Schedule cron expression.
"""
- self._sync_gca_resource()
- return self._gca_resource.cron
-
- @property
- def max_concurrent_run_count(self) -> int:
- """Current Schedule max_concurrent_run_count.
-
- Returns:
- Schedule max_concurrent_run_count.
- """
- self._sync_gca_resource()
- return self._gca_resource.max_concurrent_run_count
-
- @property
- def allow_queueing(self) -> bool:
- """Whether current Schedule allows queueing.
-
- Returns:
- Schedule allow_queueing.
- """
- self._sync_gca_resource()
- return self._gca_resource.allow_queueing
-
- def _block_until_complete(self) -> None:
- """Helper method to block and check on Schedule until complete."""
- # Used these numbers so failures surface fast
- wait = 5 # start at five seconds
- log_wait = 5
- max_wait = 60 * 5 # 5 minute wait
- multiplier = 2 # scale wait by 2 every iteration
-
- previous_time = time.time()
- while self.state not in _SCHEDULE_COMPLETE_STATES:
- current_time = time.time()
- if current_time - previous_time >= log_wait:
- _LOGGER.info(
- "%s %s current state:\n%s"
- % (
- self.__class__.__name__,
- self._gca_resource.name,
- self._gca_resource.state,
- )
- )
- log_wait = min(log_wait * multiplier, max_wait)
- previous_time = current_time
- time.sleep(wait)
-
- # Error is only populated when the schedule state is STATE_UNSPECIFIED.
- if self._gca_resource.state in _SCHEDULE_ERROR_STATES:
- raise RuntimeError("Schedule failed with:\n%s" % self._gca_resource.error)
- else:
- _LOGGER.log_action_completed_against_resource("run", "completed", self)
-
- def _dashboard_uri(self) -> str:
- """Helper method to compose the dashboard uri where Schedule can be
- viewed.
-
- Returns:
- Dashboard uri where Schedule can be viewed.
- """
- fields = self._parse_resource_name(self.resource_name)
- url = f"https://console.cloud.google.com/vertex-ai/locations/{fields['location']}/pipelines/schedules/{fields['schedule']}?project={fields['project']}"
- return url
+ return super().cron
diff --git a/google/cloud/aiplatform/schedules.py b/google/cloud/aiplatform/schedules.py
new file mode 100644
index 0000000000..2083a60dbc
--- /dev/null
+++ b/google/cloud/aiplatform/schedules.py
@@ -0,0 +1,238 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+from typing import Any, Optional
+
+from google.auth import credentials as auth_credentials
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import utils
+from google.cloud.aiplatform.compat.types import (
+ schedule as gca_schedule,
+)
+from google.cloud.aiplatform.constants import (
+ schedule as schedule_constants,
+)
+
+_LOGGER = base.Logger(__name__)
+
+_SCHEDULE_COMPLETE_STATES = schedule_constants._SCHEDULE_COMPLETE_STATES
+
+_SCHEDULE_ERROR_STATES = schedule_constants._SCHEDULE_ERROR_STATES
+
+
+class _Schedule(
+ base.VertexAiStatefulResource,
+):
+ """Schedule resource for Vertex AI."""
+
+ client_class = utils.ScheduleClientWithOverride
+ _resource_noun = "schedules"
+ _delete_method = "delete_schedule"
+ _getter_method = "get_schedule"
+ _list_method = "list_schedules"
+ _pause_method = "pause_schedule"
+ _resume_method = "resume_schedule"
+ _parse_resource_name_method = "parse_schedule_path"
+ _format_resource_name_method = "schedule_path"
+
+ # Required by the done() method
+ _valid_done_states = schedule_constants._SCHEDULE_COMPLETE_STATES
+
+ def __init__(
+ self,
+ credentials: auth_credentials.Credentials,
+ project: str,
+ location: str,
+ ):
+ """Retrieves a Schedule resource and instantiates its representation.
+ Args:
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to create this Schedule.
+ Overrides credentials set in aiplatform.init.
+ project (str):
+ Optional. The project that you want to run this Schedule in.
+ If not set, the project set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to create Schedule. If not set,
+ location set in aiplatform.init will be used.
+ """
+ super().__init__(project=project, location=location, credentials=credentials)
+
+ @classmethod
+ def get(
+ cls,
+ schedule_id: str,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> Any:
+ """Get a Vertex AI Schedule for the given resource_name.
+
+ Args:
+ schedule_id (str):
+ Required. Schedule ID used to identify or locate the schedule.
+ project (str):
+ Optional. Project to retrieve dataset from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve dataset from. If not set,
+ location set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to upload this model.
+ Overrides credentials set in aiplatform.init.
+
+ Returns:
+ A Vertex AI Schedule.
+ """
+ self = cls._empty_constructor(
+ project=project,
+ location=location,
+ credentials=credentials,
+ resource_name=schedule_id,
+ )
+
+ self._gca_resource = self._get_gca_resource(resource_name=schedule_id)
+
+ return self
+
+ def pause(self) -> None:
+ """Starts asynchronous pause on the Schedule.
+
+ Changes Schedule state from State.ACTIVE to State.PAUSED.
+ """
+ self.api_client.pause_schedule(name=self.resource_name)
+
+ def resume(
+ self,
+ catch_up: bool = True,
+ ) -> None:
+ """Starts asynchronous resume on the Schedule.
+
+ Changes Schedule state from State.PAUSED to State.ACTIVE.
+
+ Args:
+ catch_up (bool):
+ Optional. Whether to backfill missed runs when the Schedule is
+ resumed from State.PAUSED.
+ """
+ self.api_client.resume_schedule(name=self.resource_name)
+
+ def done(self) -> bool:
+ """Helper method that return True is Schedule is done. False otherwise."""
+ if not self._gca_resource:
+ return False
+
+ return self.state in _SCHEDULE_COMPLETE_STATES
+
+ def wait(self) -> None:
+ """Wait for all runs scheduled by this Schedule to complete."""
+ if self._latest_future is None:
+ self._block_until_complete()
+ else:
+ super().wait()
+
+ @property
+ def state(self) -> Optional[gca_schedule.Schedule.State]:
+ """Current Schedule state.
+
+ Returns:
+ Schedule state.
+ """
+ self._sync_gca_resource()
+ return self._gca_resource.state
+
+ @property
+ def max_run_count(self) -> int:
+ """Current Schedule max_run_count.
+
+ Returns:
+ Schedule max_run_count.
+ """
+ self._sync_gca_resource()
+ return self._gca_resource.max_run_count
+
+ @property
+ def cron(self) -> str:
+ """Current Schedule cron.
+
+ Returns:
+ Schedule cron.
+ """
+ self._sync_gca_resource()
+ return self._gca_resource.cron
+
+ @property
+ def max_concurrent_run_count(self) -> int:
+ """Current Schedule max_concurrent_run_count.
+
+ Returns:
+ Schedule max_concurrent_run_count.
+ """
+ self._sync_gca_resource()
+ return self._gca_resource.max_concurrent_run_count
+
+ @property
+ def allow_queueing(self) -> bool:
+ """Whether current Schedule allows queueing.
+
+ Returns:
+ Schedule allow_queueing.
+ """
+ self._sync_gca_resource()
+ return self._gca_resource.allow_queueing
+
+ def _block_until_complete(self) -> None:
+ """Helper method to block and check on Schedule until complete."""
+ # Used these numbers so failures surface fast
+ wait = 5 # start at five seconds
+ log_wait = 5
+ max_wait = 60 * 5 # 5 minute wait
+ multiplier = 2 # scale wait by 2 every iteration
+
+ previous_time = time.time()
+ while self.state not in _SCHEDULE_COMPLETE_STATES:
+ current_time = time.time()
+ if current_time - previous_time >= log_wait:
+ _LOGGER.info(
+ "%s %s current state:\n%s"
+ % (
+ self.__class__.__name__,
+ self._gca_resource.name,
+ self._gca_resource.state,
+ )
+ )
+ log_wait = min(log_wait * multiplier, max_wait)
+ previous_time = current_time
+ time.sleep(wait)
+
+ # Error is only populated when the schedule state is STATE_UNSPECIFIED.
+ if self._gca_resource.state in _SCHEDULE_ERROR_STATES:
+ raise RuntimeError("Schedule failed with:\n%s" % self._gca_resource.error)
+ else:
+ _LOGGER.log_action_completed_against_resource("run", "completed", self)
+
+ def _dashboard_uri(self) -> str:
+ """Helper method to compose the dashboard uri where Schedule can be
+ viewed.
+
+ Returns:
+ Dashboard uri where Schedule can be viewed.
+ """
+ fields = self._parse_resource_name(self.resource_name)
+ url = f"https://console.cloud.google.com/vertex-ai/locations/{fields['location']}/pipelines/schedules/{fields['schedule']}?project={fields['project']}"
+ return url
diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py
index 0cfb28c462..7af003d185 100644
--- a/google/cloud/aiplatform/training_jobs.py
+++ b/google/cloud/aiplatform/training_jobs.py
@@ -1488,6 +1488,7 @@ def _prepare_training_task_inputs_and_output_dir(
enable_web_access: bool = False,
enable_dashboard_access: bool = False,
tensorboard: Optional[str] = None,
+ disable_retries: bool = False,
) -> Tuple[Dict, str]:
"""Prepares training task inputs and output directory for custom job.
@@ -1534,6 +1535,10 @@ def _prepare_training_task_inputs_and_output_dir(
`service_account` is required with provided `tensorboard`.
For more information on configuring your service account please visit:
https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
+ disable_retries (bool):
+ Indicates if the job should retry for internal errors after the
+ job starts running. If True, overrides
+ `restart_job_on_worker_restart` to False.
Returns:
Training task inputs and Output directory for custom job.
"""
@@ -1561,11 +1566,12 @@ def _prepare_training_task_inputs_and_output_dir(
if enable_dashboard_access:
training_task_inputs["enable_dashboard_access"] = enable_dashboard_access
- if timeout or restart_job_on_worker_restart:
+ if timeout or restart_job_on_worker_restart or disable_retries:
timeout = f"{timeout}s" if timeout else None
scheduling = {
"timeout": timeout,
"restart_job_on_worker_restart": restart_job_on_worker_restart,
+ "disable_retries": disable_retries,
}
training_task_inputs["scheduling"] = scheduling
@@ -2923,6 +2929,7 @@ def run(
tensorboard: Optional[str] = None,
sync=True,
create_request_timeout: Optional[float] = None,
+ disable_retries: bool = False,
) -> Optional[models.Model]:
"""Runs the custom training job.
@@ -3206,6 +3213,10 @@ def run(
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ disable_retries (bool):
+ Indicates if the job should retry for internal errors after the
+ job starts running. If True, overrides
+ `restart_job_on_worker_restart` to False.
Returns:
model: The trained Vertex AI Model resource or None if training did not
@@ -3266,6 +3277,7 @@ def run(
else None,
sync=sync,
create_request_timeout=create_request_timeout,
+ disable_retries=disable_retries,
)
def submit(
@@ -3316,6 +3328,7 @@ def submit(
tensorboard: Optional[str] = None,
sync=True,
create_request_timeout: Optional[float] = None,
+ disable_retries: bool = False,
) -> Optional[models.Model]:
"""Submits the custom training job without blocking until completion.
@@ -3599,6 +3612,10 @@ def submit(
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ disable_retries (bool):
+ Indicates if the job should retry for internal errors after the
+ job starts running. If True, overrides
+ `restart_job_on_worker_restart` to False.
Returns:
model: The trained Vertex AI Model resource or None if training did not
@@ -3660,6 +3677,7 @@ def submit(
sync=sync,
create_request_timeout=create_request_timeout,
block=False,
+ disable_retries=disable_retries,
)
@base.optional_sync(construct_object_on_arg="managed_model")
@@ -3705,6 +3723,7 @@ def _run(
sync=True,
create_request_timeout: Optional[float] = None,
block: Optional[bool] = True,
+ disable_retries: bool = False,
) -> Optional[models.Model]:
"""Packages local script and launches training_job.
@@ -3890,6 +3909,10 @@ def _run(
Optional. The timeout for the create request in seconds
block (bool):
Optional. If True, block until complete.
+ disable_retries (bool):
+ Indicates if the job should retry for internal errors after the
+ job starts running. If True, overrides
+ `restart_job_on_worker_restart` to False.
Returns:
model: The trained Vertex AI Model resource or None if training did not
@@ -3942,6 +3965,7 @@ def _run(
enable_web_access=enable_web_access,
enable_dashboard_access=enable_dashboard_access,
tensorboard=tensorboard,
+ disable_retries=disable_retries,
)
model = self._run_job(
@@ -4263,6 +4287,7 @@ def run(
tensorboard: Optional[str] = None,
sync=True,
create_request_timeout: Optional[float] = None,
+ disable_retries: bool = False,
) -> Optional[models.Model]:
"""Runs the custom training job.
@@ -4539,6 +4564,10 @@ def run(
be immediately returned and synced when the Future has completed.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
+ disable_retries (bool):
+ Indicates if the job should retry for internal errors after the
+ job starts running. If True, overrides
+ `restart_job_on_worker_restart` to False.
Returns:
model: The trained Vertex AI Model resource or None if training did not
@@ -4598,6 +4627,7 @@ def run(
else None,
sync=sync,
create_request_timeout=create_request_timeout,
+ disable_retries=disable_retries,
)
def submit(
@@ -4648,6 +4678,7 @@ def submit(
tensorboard: Optional[str] = None,
sync=True,
create_request_timeout: Optional[float] = None,
+ disable_retries: bool = False,
) -> Optional[models.Model]:
"""Submits the custom training job without blocking until completion.
@@ -4924,6 +4955,10 @@ def submit(
be immediately returned and synced when the Future has completed.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
+ disable_retries (bool):
+ Indicates if the job should retry for internal errors after the
+ job starts running. If True, overrides
+ `restart_job_on_worker_restart` to False.
Returns:
model: The trained Vertex AI Model resource or None if training did not
@@ -4984,6 +5019,7 @@ def submit(
sync=sync,
create_request_timeout=create_request_timeout,
block=False,
+ disable_retries=disable_retries,
)
@base.optional_sync(construct_object_on_arg="managed_model")
@@ -5028,6 +5064,7 @@ def _run(
sync=True,
create_request_timeout: Optional[float] = None,
block: Optional[bool] = True,
+ disable_retries: bool = False,
) -> Optional[models.Model]:
"""Packages local script and launches training_job.
Args:
@@ -5209,6 +5246,10 @@ def _run(
Optional. The timeout for the create request in seconds.
block (bool):
Optional. If True, block until complete.
+ disable_retries (bool):
+ Indicates if the job should retry for internal errors after the
+ job starts running. If True, overrides
+ `restart_job_on_worker_restart` to False.
Returns:
model: The trained Vertex AI Model resource or None if training did not
@@ -5255,6 +5296,7 @@ def _run(
enable_web_access=enable_web_access,
enable_dashboard_access=enable_dashboard_access,
tensorboard=tensorboard,
+ disable_retries=disable_retries,
)
model = self._run_job(
@@ -7172,6 +7214,7 @@ def run(
tensorboard: Optional[str] = None,
sync=True,
create_request_timeout: Optional[float] = None,
+ disable_retries: bool = False,
) -> Optional[models.Model]:
"""Runs the custom training job.
@@ -7448,6 +7491,10 @@ def run(
be immediately returned and synced when the Future has completed.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
+ disable_retries (bool):
+ Indicates if the job should retry for internal errors after the
+ job starts running. If True, overrides
+ `restart_job_on_worker_restart` to False.
Returns:
model: The trained Vertex AI Model resource or None if training did not
@@ -7502,6 +7549,7 @@ def run(
else None,
sync=sync,
create_request_timeout=create_request_timeout,
+ disable_retries=disable_retries,
)
@base.optional_sync(construct_object_on_arg="managed_model")
@@ -7545,6 +7593,7 @@ def _run(
reduction_server_container_uri: Optional[str] = None,
sync=True,
create_request_timeout: Optional[float] = None,
+ disable_retries: bool = False,
) -> Optional[models.Model]:
"""Packages local script and launches training_job.
@@ -7711,6 +7760,10 @@ def _run(
be immediately returned and synced when the Future has completed.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
+ disable_retries (bool):
+ Indicates if the job should retry for internal errors after the
+ job starts running. If True, overrides
+ `restart_job_on_worker_restart` to False.
Returns:
model: The trained Vertex AI Model resource or None if training did not
@@ -7757,6 +7810,7 @@ def _run(
enable_web_access=enable_web_access,
enable_dashboard_access=enable_dashboard_access,
tensorboard=tensorboard,
+ disable_retries=disable_retries,
)
model = self._run_job(
diff --git a/google/cloud/aiplatform/utils/__init__.py b/google/cloud/aiplatform/utils/__init__.py
index d6ccf17e6f..a5dea7a112 100644
--- a/google/cloud/aiplatform/utils/__init__.py
+++ b/google/cloud/aiplatform/utils/__init__.py
@@ -53,6 +53,7 @@
tensorboard_service_client_v1beta1,
vizier_service_client_v1beta1,
model_garden_service_client_v1beta1,
+ persistent_resource_service_client_v1beta1,
)
from google.cloud.aiplatform.compat.services import (
dataset_service_client_v1,
@@ -67,6 +68,7 @@
model_service_client_v1,
pipeline_service_client_v1,
prediction_service_client_v1,
+ schedule_service_client_v1,
tensorboard_service_client_v1,
vizier_service_client_v1,
)
@@ -104,6 +106,7 @@
prediction_service_client_v1.PredictionServiceClient,
pipeline_service_client_v1.PipelineServiceClient,
job_service_client_v1.JobServiceClient,
+ schedule_service_client_v1.ScheduleServiceClient,
tensorboard_service_client_v1.TensorboardServiceClient,
vizier_service_client_v1.VizierServiceClient,
)
@@ -597,8 +600,9 @@ class PipelineJobClientWithOverride(ClientWithOverride):
class ScheduleClientWithOverride(ClientWithOverride):
_is_temporary = True
- _default_version = compat.V1BETA1
+ _default_version = compat.DEFAULT_VERSION
_version_map = (
+ (compat.V1, schedule_service_client_v1.ScheduleServiceClient),
(compat.V1BETA1, schedule_service_client_v1beta1.ScheduleServiceClient),
)
@@ -654,6 +658,17 @@ class ModelGardenClientWithOverride(ClientWithOverride):
)
+class PersistentResourceClientWithOverride(ClientWithOverride):
+ _is_temporary = True
+ _default_version = compat.V1BETA1
+ _version_map = (
+ (
+ compat.V1BETA1,
+ persistent_resource_service_client_v1beta1.PersistentResourceServiceClient,
+ ),
+ )
+
+
VertexAiServiceClientWithOverride = TypeVar(
"VertexAiServiceClientWithOverride",
DatasetClientWithOverride,
@@ -670,6 +685,7 @@ class ModelGardenClientWithOverride(ClientWithOverride):
TensorboardClientWithOverride,
VizierClientWithOverride,
ModelGardenClientWithOverride,
+ PersistentResourceClientWithOverride,
)
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py
index dcedf0af71..b26e773716 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.30.1" # {x-release-please-version}
+__version__ = "1.31.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py
index dcedf0af71..b26e773716 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.30.1" # {x-release-please-version}
+__version__ = "1.31.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py
index dcedf0af71..b26e773716 100644
--- a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.30.1" # {x-release-please-version}
+__version__ = "1.31.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py
index dcedf0af71..b26e773716 100644
--- a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.30.1" # {x-release-please-version}
+__version__ = "1.31.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py
index dcedf0af71..b26e773716 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.30.1" # {x-release-please-version}
+__version__ = "1.31.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py
index dcedf0af71..b26e773716 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.30.1" # {x-release-please-version}
+__version__ = "1.31.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py
index dcedf0af71..b26e773716 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.30.1" # {x-release-please-version}
+__version__ = "1.31.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py
index dcedf0af71..b26e773716 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.30.1" # {x-release-please-version}
+__version__ = "1.31.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py
index dcedf0af71..b26e773716 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.30.1" # {x-release-please-version}
+__version__ = "1.31.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py
index dcedf0af71..b26e773716 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.30.1" # {x-release-please-version}
+__version__ = "1.31.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py
index dcedf0af71..b26e773716 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.30.1" # {x-release-please-version}
+__version__ = "1.31.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py
index dcedf0af71..b26e773716 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.30.1" # {x-release-please-version}
+__version__ = "1.31.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py
index dcedf0af71..b26e773716 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.30.1" # {x-release-please-version}
+__version__ = "1.31.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py
index dcedf0af71..b26e773716 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.30.1" # {x-release-please-version}
+__version__ = "1.31.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py
index dcedf0af71..b26e773716 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.30.1" # {x-release-please-version}
+__version__ = "1.31.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py
index dcedf0af71..b26e773716 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.30.1" # {x-release-please-version}
+__version__ = "1.31.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/version.py b/google/cloud/aiplatform/version.py
index a67882433e..9b7bdf307b 100644
--- a/google/cloud/aiplatform/version.py
+++ b/google/cloud/aiplatform/version.py
@@ -15,4 +15,4 @@
# limitations under the License.
#
-__version__ = "1.30.1"
+__version__ = "1.31.0"
diff --git a/google/cloud/aiplatform_v1/__init__.py b/google/cloud/aiplatform_v1/__init__.py
index 8b53d0eabb..3ac4a1d256 100644
--- a/google/cloud/aiplatform_v1/__init__.py
+++ b/google/cloud/aiplatform_v1/__init__.py
@@ -83,6 +83,7 @@
from .types.dataset import ImportDataConfig
from .types.dataset_service import CreateDatasetOperationMetadata
from .types.dataset_service import CreateDatasetRequest
+from .types.dataset_service import CreateDatasetVersionOperationMetadata
from .types.dataset_service import DataItemView
from .types.dataset_service import DeleteDatasetRequest
from .types.dataset_service import DeleteSavedQueryRequest
@@ -102,6 +103,7 @@
from .types.dataset_service import ListDatasetsResponse
from .types.dataset_service import ListSavedQueriesRequest
from .types.dataset_service import ListSavedQueriesResponse
+from .types.dataset_service import RestoreDatasetVersionOperationMetadata
from .types.dataset_service import SearchDataItemsRequest
from .types.dataset_service import SearchDataItemsResponse
from .types.dataset_service import UpdateDatasetRequest
@@ -532,6 +534,8 @@
from .types.tensorboard_service import ListTensorboardTimeSeriesResponse
from .types.tensorboard_service import ReadTensorboardBlobDataRequest
from .types.tensorboard_service import ReadTensorboardBlobDataResponse
+from .types.tensorboard_service import ReadTensorboardSizeRequest
+from .types.tensorboard_service import ReadTensorboardSizeResponse
from .types.tensorboard_service import ReadTensorboardTimeSeriesDataRequest
from .types.tensorboard_service import ReadTensorboardTimeSeriesDataResponse
from .types.tensorboard_service import ReadTensorboardUsageRequest
@@ -669,6 +673,7 @@
"CreateDataLabelingJobRequest",
"CreateDatasetOperationMetadata",
"CreateDatasetRequest",
+ "CreateDatasetVersionOperationMetadata",
"CreateEndpointOperationMetadata",
"CreateEndpointRequest",
"CreateEntityTypeOperationMetadata",
@@ -1029,6 +1034,8 @@
"ReadIndexDatapointsResponse",
"ReadTensorboardBlobDataRequest",
"ReadTensorboardBlobDataResponse",
+ "ReadTensorboardSizeRequest",
+ "ReadTensorboardSizeResponse",
"ReadTensorboardTimeSeriesDataRequest",
"ReadTensorboardTimeSeriesDataResponse",
"ReadTensorboardUsageRequest",
@@ -1038,6 +1045,7 @@
"RemoveDatapointsRequest",
"RemoveDatapointsResponse",
"ResourcesConsumed",
+ "RestoreDatasetVersionOperationMetadata",
"ResumeModelDeploymentMonitoringJobRequest",
"ResumeScheduleRequest",
"SampleConfig",
diff --git a/google/cloud/aiplatform_v1/gapic_metadata.json b/google/cloud/aiplatform_v1/gapic_metadata.json
index c100fe1214..64926ddcc3 100644
--- a/google/cloud/aiplatform_v1/gapic_metadata.json
+++ b/google/cloud/aiplatform_v1/gapic_metadata.json
@@ -1925,6 +1925,90 @@
}
}
},
+ "ScheduleService": {
+ "clients": {
+ "grpc": {
+ "libraryClient": "ScheduleServiceClient",
+ "rpcs": {
+ "CreateSchedule": {
+ "methods": [
+ "create_schedule"
+ ]
+ },
+ "DeleteSchedule": {
+ "methods": [
+ "delete_schedule"
+ ]
+ },
+ "GetSchedule": {
+ "methods": [
+ "get_schedule"
+ ]
+ },
+ "ListSchedules": {
+ "methods": [
+ "list_schedules"
+ ]
+ },
+ "PauseSchedule": {
+ "methods": [
+ "pause_schedule"
+ ]
+ },
+ "ResumeSchedule": {
+ "methods": [
+ "resume_schedule"
+ ]
+ },
+ "UpdateSchedule": {
+ "methods": [
+ "update_schedule"
+ ]
+ }
+ }
+ },
+ "grpc-async": {
+ "libraryClient": "ScheduleServiceAsyncClient",
+ "rpcs": {
+ "CreateSchedule": {
+ "methods": [
+ "create_schedule"
+ ]
+ },
+ "DeleteSchedule": {
+ "methods": [
+ "delete_schedule"
+ ]
+ },
+ "GetSchedule": {
+ "methods": [
+ "get_schedule"
+ ]
+ },
+ "ListSchedules": {
+ "methods": [
+ "list_schedules"
+ ]
+ },
+ "PauseSchedule": {
+ "methods": [
+ "pause_schedule"
+ ]
+ },
+ "ResumeSchedule": {
+ "methods": [
+ "resume_schedule"
+ ]
+ },
+ "UpdateSchedule": {
+ "methods": [
+ "update_schedule"
+ ]
+ }
+ }
+ }
+ }
+ },
"SpecialistPoolService": {
"clients": {
"grpc": {
@@ -2099,6 +2183,11 @@
"read_tensorboard_blob_data"
]
},
+ "ReadTensorboardSize": {
+ "methods": [
+ "read_tensorboard_size"
+ ]
+ },
"ReadTensorboardTimeSeriesData": {
"methods": [
"read_tensorboard_time_series_data"
@@ -2249,6 +2338,11 @@
"read_tensorboard_blob_data"
]
},
+ "ReadTensorboardSize": {
+ "methods": [
+ "read_tensorboard_size"
+ ]
+ },
"ReadTensorboardTimeSeriesData": {
"methods": [
"read_tensorboard_time_series_data"
diff --git a/google/cloud/aiplatform_v1/gapic_version.py b/google/cloud/aiplatform_v1/gapic_version.py
index dcedf0af71..b26e773716 100644
--- a/google/cloud/aiplatform_v1/gapic_version.py
+++ b/google/cloud/aiplatform_v1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.30.1" # {x-release-please-version}
+__version__ = "1.31.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform_v1/services/prediction_service/client.py b/google/cloud/aiplatform_v1/services/prediction_service/client.py
index 505760b293..c3f705d9e5 100644
--- a/google/cloud/aiplatform_v1/services/prediction_service/client.py
+++ b/google/cloud/aiplatform_v1/services/prediction_service/client.py
@@ -16,8 +16,6 @@
from collections import OrderedDict
import os
import re
-
-import pkg_resources
from typing import (
Dict,
Mapping,
diff --git a/google/cloud/aiplatform_v1/services/schedule_service/async_client.py b/google/cloud/aiplatform_v1/services/schedule_service/async_client.py
index bed5bc7db6..8277873dbc 100644
--- a/google/cloud/aiplatform_v1/services/schedule_service/async_client.py
+++ b/google/cloud/aiplatform_v1/services/schedule_service/async_client.py
@@ -975,6 +975,7 @@ async def sample_update_schedule():
the server. The following restrictions will be applied:
- The scheduled request type cannot be changed.
+ - The non-empty fields cannot be unset.
- The output_only fields will be ignored if specified.
This corresponds to the ``schedule`` field
diff --git a/google/cloud/aiplatform_v1/services/schedule_service/client.py b/google/cloud/aiplatform_v1/services/schedule_service/client.py
index d6249545b2..03bc1a247e 100644
--- a/google/cloud/aiplatform_v1/services/schedule_service/client.py
+++ b/google/cloud/aiplatform_v1/services/schedule_service/client.py
@@ -1323,6 +1323,7 @@ def sample_update_schedule():
the server. The following restrictions will be applied:
- The scheduled request type cannot be changed.
+ - The non-empty fields cannot be unset.
- The output_only fields will be ignored if specified.
This corresponds to the ``schedule`` field
diff --git a/google/cloud/aiplatform_v1/services/tensorboard_service/async_client.py b/google/cloud/aiplatform_v1/services/tensorboard_service/async_client.py
index 94408634d7..3de1006682 100644
--- a/google/cloud/aiplatform_v1/services/tensorboard_service/async_client.py
+++ b/google/cloud/aiplatform_v1/services/tensorboard_service/async_client.py
@@ -992,6 +992,113 @@ async def sample_read_tensorboard_usage():
# Done; return the response.
return response
+ async def read_tensorboard_size(
+ self,
+ request: Optional[
+ Union[tensorboard_service.ReadTensorboardSizeRequest, dict]
+ ] = None,
+ *,
+ tensorboard: Optional[str] = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: Union[float, object] = gapic_v1.method.DEFAULT,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> tensorboard_service.ReadTensorboardSizeResponse:
+ r"""Returns the storage size for a given TensorBoard
+ instance.
+
+ .. code-block:: python
+
+ # This snippet has been automatically generated and should be regarded as a
+ # code template only.
+ # It will require modifications to work:
+ # - It may require correct/in-range values for request initialization.
+ # - It may require specifying regional endpoints when creating the service
+ # client as shown in:
+ # https://googleapis.dev/python/google-api-core/latest/client_options.html
+ from google.cloud import aiplatform_v1
+
+ async def sample_read_tensorboard_size():
+ # Create a client
+ client = aiplatform_v1.TensorboardServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.ReadTensorboardSizeRequest(
+ tensorboard="tensorboard_value",
+ )
+
+ # Make the request
+ response = await client.read_tensorboard_size(request=request)
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Optional[Union[google.cloud.aiplatform_v1.types.ReadTensorboardSizeRequest, dict]]):
+ The request object. Request message for
+ [TensorboardService.ReadTensorboardSize][google.cloud.aiplatform.v1.TensorboardService.ReadTensorboardSize].
+ tensorboard (:class:`str`):
+ Required. The name of the Tensorboard resource. Format:
+ ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
+
+ This corresponds to the ``tensorboard`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.aiplatform_v1.types.ReadTensorboardSizeResponse:
+ Response message for
+ [TensorboardService.ReadTensorboardSize][google.cloud.aiplatform.v1.TensorboardService.ReadTensorboardSize].
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([tensorboard])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = tensorboard_service.ReadTensorboardSizeRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if tensorboard is not None:
+ request.tensorboard = tensorboard
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.read_tensorboard_size,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata(
+ (("tensorboard", request.tensorboard),)
+ ),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
async def create_tensorboard_experiment(
self,
request: Optional[
diff --git a/google/cloud/aiplatform_v1/services/tensorboard_service/client.py b/google/cloud/aiplatform_v1/services/tensorboard_service/client.py
index 2a9ce909c4..34f95b1bac 100644
--- a/google/cloud/aiplatform_v1/services/tensorboard_service/client.py
+++ b/google/cloud/aiplatform_v1/services/tensorboard_service/client.py
@@ -1279,6 +1279,113 @@ def sample_read_tensorboard_usage():
# Done; return the response.
return response
+ def read_tensorboard_size(
+ self,
+ request: Optional[
+ Union[tensorboard_service.ReadTensorboardSizeRequest, dict]
+ ] = None,
+ *,
+ tensorboard: Optional[str] = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: Union[float, object] = gapic_v1.method.DEFAULT,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> tensorboard_service.ReadTensorboardSizeResponse:
+ r"""Returns the storage size for a given TensorBoard
+ instance.
+
+ .. code-block:: python
+
+ # This snippet has been automatically generated and should be regarded as a
+ # code template only.
+ # It will require modifications to work:
+ # - It may require correct/in-range values for request initialization.
+ # - It may require specifying regional endpoints when creating the service
+ # client as shown in:
+ # https://googleapis.dev/python/google-api-core/latest/client_options.html
+ from google.cloud import aiplatform_v1
+
+ def sample_read_tensorboard_size():
+ # Create a client
+ client = aiplatform_v1.TensorboardServiceClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.ReadTensorboardSizeRequest(
+ tensorboard="tensorboard_value",
+ )
+
+ # Make the request
+ response = client.read_tensorboard_size(request=request)
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.ReadTensorboardSizeRequest, dict]):
+ The request object. Request message for
+ [TensorboardService.ReadTensorboardSize][google.cloud.aiplatform.v1.TensorboardService.ReadTensorboardSize].
+ tensorboard (str):
+ Required. The name of the Tensorboard resource. Format:
+ ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
+
+ This corresponds to the ``tensorboard`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.aiplatform_v1.types.ReadTensorboardSizeResponse:
+ Response message for
+ [TensorboardService.ReadTensorboardSize][google.cloud.aiplatform.v1.TensorboardService.ReadTensorboardSize].
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([tensorboard])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ # Minor optimization to avoid making a copy if the user passes
+ # in a tensorboard_service.ReadTensorboardSizeRequest.
+ # There's no risk of modifying the input as we've already verified
+ # there are no flattened fields.
+ if not isinstance(request, tensorboard_service.ReadTensorboardSizeRequest):
+ request = tensorboard_service.ReadTensorboardSizeRequest(request)
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if tensorboard is not None:
+ request.tensorboard = tensorboard
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._transport._wrapped_methods[self._transport.read_tensorboard_size]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata(
+ (("tensorboard", request.tensorboard),)
+ ),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
def create_tensorboard_experiment(
self,
request: Optional[
diff --git a/google/cloud/aiplatform_v1/services/tensorboard_service/transports/base.py b/google/cloud/aiplatform_v1/services/tensorboard_service/transports/base.py
index 37f4db9472..fb3b4d7daf 100644
--- a/google/cloud/aiplatform_v1/services/tensorboard_service/transports/base.py
+++ b/google/cloud/aiplatform_v1/services/tensorboard_service/transports/base.py
@@ -173,6 +173,11 @@ def _prep_wrapped_messages(self, client_info):
default_timeout=None,
client_info=client_info,
),
+ self.read_tensorboard_size: gapic_v1.method.wrap_method(
+ self.read_tensorboard_size,
+ default_timeout=None,
+ client_info=client_info,
+ ),
self.create_tensorboard_experiment: gapic_v1.method.wrap_method(
self.create_tensorboard_experiment,
default_timeout=None,
@@ -364,6 +369,18 @@ def read_tensorboard_usage(
]:
raise NotImplementedError()
+ @property
+ def read_tensorboard_size(
+ self,
+ ) -> Callable[
+ [tensorboard_service.ReadTensorboardSizeRequest],
+ Union[
+ tensorboard_service.ReadTensorboardSizeResponse,
+ Awaitable[tensorboard_service.ReadTensorboardSizeResponse],
+ ],
+ ]:
+ raise NotImplementedError()
+
@property
def create_tensorboard_experiment(
self,
diff --git a/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc.py
index e7ad57fce9..9b8ef6fc72 100644
--- a/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc.py
+++ b/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc.py
@@ -430,6 +430,36 @@ def read_tensorboard_usage(
)
return self._stubs["read_tensorboard_usage"]
+ @property
+ def read_tensorboard_size(
+ self,
+ ) -> Callable[
+ [tensorboard_service.ReadTensorboardSizeRequest],
+ tensorboard_service.ReadTensorboardSizeResponse,
+ ]:
+ r"""Return a callable for the read tensorboard size method over gRPC.
+
+ Returns the storage size for a given TensorBoard
+ instance.
+
+ Returns:
+ Callable[[~.ReadTensorboardSizeRequest],
+ ~.ReadTensorboardSizeResponse]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "read_tensorboard_size" not in self._stubs:
+ self._stubs["read_tensorboard_size"] = self.grpc_channel.unary_unary(
+ "/google.cloud.aiplatform.v1.TensorboardService/ReadTensorboardSize",
+ request_serializer=tensorboard_service.ReadTensorboardSizeRequest.serialize,
+ response_deserializer=tensorboard_service.ReadTensorboardSizeResponse.deserialize,
+ )
+ return self._stubs["read_tensorboard_size"]
+
@property
def create_tensorboard_experiment(
self,
diff --git a/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc_asyncio.py
index 86438e49a7..dca4f30bc5 100644
--- a/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc_asyncio.py
+++ b/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc_asyncio.py
@@ -440,6 +440,36 @@ def read_tensorboard_usage(
)
return self._stubs["read_tensorboard_usage"]
+ @property
+ def read_tensorboard_size(
+ self,
+ ) -> Callable[
+ [tensorboard_service.ReadTensorboardSizeRequest],
+ Awaitable[tensorboard_service.ReadTensorboardSizeResponse],
+ ]:
+ r"""Return a callable for the read tensorboard size method over gRPC.
+
+ Returns the storage size for a given TensorBoard
+ instance.
+
+ Returns:
+ Callable[[~.ReadTensorboardSizeRequest],
+ Awaitable[~.ReadTensorboardSizeResponse]]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "read_tensorboard_size" not in self._stubs:
+ self._stubs["read_tensorboard_size"] = self.grpc_channel.unary_unary(
+ "/google.cloud.aiplatform.v1.TensorboardService/ReadTensorboardSize",
+ request_serializer=tensorboard_service.ReadTensorboardSizeRequest.serialize,
+ response_deserializer=tensorboard_service.ReadTensorboardSizeResponse.deserialize,
+ )
+ return self._stubs["read_tensorboard_size"]
+
@property
def create_tensorboard_experiment(
self,
diff --git a/google/cloud/aiplatform_v1/types/__init__.py b/google/cloud/aiplatform_v1/types/__init__.py
index 5e00d82b4d..e3e39bf923 100644
--- a/google/cloud/aiplatform_v1/types/__init__.py
+++ b/google/cloud/aiplatform_v1/types/__init__.py
@@ -57,6 +57,7 @@
from .dataset_service import (
CreateDatasetOperationMetadata,
CreateDatasetRequest,
+ CreateDatasetVersionOperationMetadata,
DataItemView,
DeleteDatasetRequest,
DeleteSavedQueryRequest,
@@ -76,6 +77,7 @@
ListDatasetsResponse,
ListSavedQueriesRequest,
ListSavedQueriesResponse,
+ RestoreDatasetVersionOperationMetadata,
SearchDataItemsRequest,
SearchDataItemsResponse,
UpdateDatasetRequest,
@@ -553,6 +555,19 @@
from .saved_query import (
SavedQuery,
)
+from .schedule import (
+ Schedule,
+)
+from .schedule_service import (
+ CreateScheduleRequest,
+ DeleteScheduleRequest,
+ GetScheduleRequest,
+ ListSchedulesRequest,
+ ListSchedulesResponse,
+ PauseScheduleRequest,
+ ResumeScheduleRequest,
+ UpdateScheduleRequest,
+)
from .service_networking import (
PrivateServiceConnectConfig,
)
@@ -624,6 +639,8 @@
ListTensorboardTimeSeriesResponse,
ReadTensorboardBlobDataRequest,
ReadTensorboardBlobDataResponse,
+ ReadTensorboardSizeRequest,
+ ReadTensorboardSizeResponse,
ReadTensorboardTimeSeriesDataRequest,
ReadTensorboardTimeSeriesDataResponse,
ReadTensorboardUsageRequest,
@@ -716,6 +733,7 @@
"ImportDataConfig",
"CreateDatasetOperationMetadata",
"CreateDatasetRequest",
+ "CreateDatasetVersionOperationMetadata",
"DataItemView",
"DeleteDatasetRequest",
"DeleteSavedQueryRequest",
@@ -735,6 +753,7 @@
"ListDatasetsResponse",
"ListSavedQueriesRequest",
"ListSavedQueriesResponse",
+ "RestoreDatasetVersionOperationMetadata",
"SearchDataItemsRequest",
"SearchDataItemsResponse",
"UpdateDatasetRequest",
@@ -1157,6 +1176,8 @@
"ListTensorboardTimeSeriesResponse",
"ReadTensorboardBlobDataRequest",
"ReadTensorboardBlobDataResponse",
+ "ReadTensorboardSizeRequest",
+ "ReadTensorboardSizeResponse",
"ReadTensorboardTimeSeriesDataRequest",
"ReadTensorboardTimeSeriesDataResponse",
"ReadTensorboardUsageRequest",
diff --git a/google/cloud/aiplatform_v1/types/context.py b/google/cloud/aiplatform_v1/types/context.py
index ee378e9bb9..e1e9bc48f3 100644
--- a/google/cloud/aiplatform_v1/types/context.py
+++ b/google/cloud/aiplatform_v1/types/context.py
@@ -36,8 +36,7 @@ class Context(proto.Message):
Attributes:
name (str):
- Output only. The resource name of the
- Context.
+ Immutable. The resource name of the Context.
display_name (str):
User provided display name of the Context.
May be up to 128 Unicode characters.
diff --git a/google/cloud/aiplatform_v1/types/custom_job.py b/google/cloud/aiplatform_v1/types/custom_job.py
index 27a3035ece..8dea1bc45e 100644
--- a/google/cloud/aiplatform_v1/types/custom_job.py
+++ b/google/cloud/aiplatform_v1/types/custom_job.py
@@ -493,6 +493,10 @@ class Scheduling(proto.Message):
gets restarted. This feature can be used by
distributed training jobs that are not resilient
to workers leaving and joining a job.
+ disable_retries (bool):
+ Optional. Indicates if the job should retry for internal
+ errors after the job starts running. If true, overrides
+ ``Scheduling.restart_job_on_worker_restart`` to false.
"""
timeout: duration_pb2.Duration = proto.Field(
@@ -504,6 +508,10 @@ class Scheduling(proto.Message):
proto.BOOL,
number=3,
)
+ disable_retries: bool = proto.Field(
+ proto.BOOL,
+ number=5,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform_v1/types/dataset_service.py b/google/cloud/aiplatform_v1/types/dataset_service.py
index 7b99f0bdcb..ee46c05722 100644
--- a/google/cloud/aiplatform_v1/types/dataset_service.py
+++ b/google/cloud/aiplatform_v1/types/dataset_service.py
@@ -43,6 +43,8 @@
"ExportDataRequest",
"ExportDataResponse",
"ExportDataOperationMetadata",
+ "CreateDatasetVersionOperationMetadata",
+ "RestoreDatasetVersionOperationMetadata",
"ListDataItemsRequest",
"ListDataItemsResponse",
"SearchDataItemsRequest",
@@ -374,6 +376,38 @@ class ExportDataOperationMetadata(proto.Message):
)
+class CreateDatasetVersionOperationMetadata(proto.Message):
+ r"""Runtime operation information for
+ [DatasetService.CreateDatasetVersion][google.cloud.aiplatform.v1.DatasetService.CreateDatasetVersion].
+
+ Attributes:
+ generic_metadata (google.cloud.aiplatform_v1.types.GenericOperationMetadata):
+ The common part of the operation metadata.
+ """
+
+ generic_metadata: operation.GenericOperationMetadata = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ message=operation.GenericOperationMetadata,
+ )
+
+
+class RestoreDatasetVersionOperationMetadata(proto.Message):
+ r"""Runtime operation information for
+ [DatasetService.RestoreDatasetVersion][google.cloud.aiplatform.v1.DatasetService.RestoreDatasetVersion].
+
+ Attributes:
+ generic_metadata (google.cloud.aiplatform_v1.types.GenericOperationMetadata):
+ The common part of the operation metadata.
+ """
+
+ generic_metadata: operation.GenericOperationMetadata = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ message=operation.GenericOperationMetadata,
+ )
+
+
class ListDataItemsRequest(proto.Message):
r"""Request message for
[DatasetService.ListDataItems][google.cloud.aiplatform.v1.DatasetService.ListDataItems].
diff --git a/google/cloud/aiplatform_v1/types/index.py b/google/cloud/aiplatform_v1/types/index.py
index a2e23aaeda..8870e2e0eb 100644
--- a/google/cloud/aiplatform_v1/types/index.py
+++ b/google/cloud/aiplatform_v1/types/index.py
@@ -197,6 +197,7 @@ class IndexDatapoint(proto.Message):
used to perform "restricted searches" where
boolean rule are used to filter the subset of
the database eligible for matching. See:
+
https://cloud.google.com/vertex-ai/docs/matching-engine/filtering
crowding_tag (google.cloud.aiplatform_v1.types.IndexDatapoint.CrowdingTag):
Optional. CrowdingTag of the datapoint, the
diff --git a/google/cloud/aiplatform_v1/types/pipeline_job.py b/google/cloud/aiplatform_v1/types/pipeline_job.py
index d60e7cbede..de5e48e564 100644
--- a/google/cloud/aiplatform_v1/types/pipeline_job.py
+++ b/google/cloud/aiplatform_v1/types/pipeline_job.py
@@ -140,6 +140,10 @@ class PipelineJob(proto.Message):
if
[PipelineJob.template_uri][google.cloud.aiplatform.v1.PipelineJob.template_uri]
is from supported template registry.
+ schedule_name (str):
+ Output only. The schedule resource name.
+ Only returned if the Pipeline is created by
+ Schedule API.
"""
class RuntimeConfig(proto.Message):
@@ -325,6 +329,10 @@ class InputArtifact(proto.Message):
number=20,
message="PipelineTemplateMetadata",
)
+ schedule_name: str = proto.Field(
+ proto.STRING,
+ number=22,
+ )
class PipelineTemplateMetadata(proto.Message):
diff --git a/google/cloud/aiplatform_v1/types/publisher_model.py b/google/cloud/aiplatform_v1/types/publisher_model.py
index c9f60ceb9c..96e3f941dd 100644
--- a/google/cloud/aiplatform_v1/types/publisher_model.py
+++ b/google/cloud/aiplatform_v1/types/publisher_model.py
@@ -206,6 +206,9 @@ class CallToAction(proto.Message):
Optional. Open in Generation AI Studio.
request_access (google.cloud.aiplatform_v1.types.PublisherModel.CallToAction.RegionalResourceReferences):
Optional. Request for access.
+ open_evaluation_pipeline (google.cloud.aiplatform_v1.types.PublisherModel.CallToAction.RegionalResourceReferences):
+ Optional. Open evaluation pipeline of the
+ PublisherModel.
"""
class RegionalResourceReferences(proto.Message):
@@ -396,6 +399,11 @@ class Deploy(proto.Message):
message="PublisherModel.CallToAction.RegionalResourceReferences",
)
)
+ open_evaluation_pipeline: "PublisherModel.CallToAction.RegionalResourceReferences" = proto.Field(
+ proto.MESSAGE,
+ number=11,
+ message="PublisherModel.CallToAction.RegionalResourceReferences",
+ )
name: str = proto.Field(
proto.STRING,
diff --git a/google/cloud/aiplatform_v1/types/schedule.py b/google/cloud/aiplatform_v1/types/schedule.py
index 7bceaf70e1..190fcac960 100644
--- a/google/cloud/aiplatform_v1/types/schedule.py
+++ b/google/cloud/aiplatform_v1/types/schedule.py
@@ -58,8 +58,7 @@ class Schedule(proto.Message):
This field is a member of `oneof`_ ``request``.
name (str):
- Output only. The resource name of the
- Schedule.
+ Immutable. The resource name of the Schedule.
display_name (str):
Required. User provided name of the Schedule.
The name can be up to 128 characters long and
diff --git a/google/cloud/aiplatform_v1/types/schedule_service.py b/google/cloud/aiplatform_v1/types/schedule_service.py
index 77ac5bc58a..c88bf7ba5b 100644
--- a/google/cloud/aiplatform_v1/types/schedule_service.py
+++ b/google/cloud/aiplatform_v1/types/schedule_service.py
@@ -274,6 +274,7 @@ class UpdateScheduleRequest(proto.Message):
server. The following restrictions will be applied:
- The scheduled request type cannot be changed.
+ - The non-empty fields cannot be unset.
- The output_only fields will be ignored if specified.
update_mask (google.protobuf.field_mask_pb2.FieldMask):
Required. The update mask applies to the resource. See
diff --git a/google/cloud/aiplatform_v1/types/tensorboard_service.py b/google/cloud/aiplatform_v1/types/tensorboard_service.py
index f89f56238c..95e63ec176 100644
--- a/google/cloud/aiplatform_v1/types/tensorboard_service.py
+++ b/google/cloud/aiplatform_v1/types/tensorboard_service.py
@@ -43,6 +43,8 @@
"DeleteTensorboardRequest",
"ReadTensorboardUsageRequest",
"ReadTensorboardUsageResponse",
+ "ReadTensorboardSizeRequest",
+ "ReadTensorboardSizeResponse",
"CreateTensorboardExperimentRequest",
"GetTensorboardExperimentRequest",
"ListTensorboardExperimentsRequest",
@@ -327,6 +329,37 @@ class PerMonthUsageData(proto.Message):
)
+class ReadTensorboardSizeRequest(proto.Message):
+ r"""Request message for
+ [TensorboardService.ReadTensorboardSize][google.cloud.aiplatform.v1.TensorboardService.ReadTensorboardSize].
+
+ Attributes:
+ tensorboard (str):
+ Required. The name of the Tensorboard resource. Format:
+ ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
+ """
+
+ tensorboard: str = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+
+
+class ReadTensorboardSizeResponse(proto.Message):
+ r"""Response message for
+ [TensorboardService.ReadTensorboardSize][google.cloud.aiplatform.v1.TensorboardService.ReadTensorboardSize].
+
+ Attributes:
+ storage_size_byte (int):
+ Payload storage size for the TensorBoard
+ """
+
+ storage_size_byte: int = proto.Field(
+ proto.INT64,
+ number=1,
+ )
+
+
class CreateTensorboardExperimentRequest(proto.Message):
r"""Request message for
[TensorboardService.CreateTensorboardExperiment][google.cloud.aiplatform.v1.TensorboardService.CreateTensorboardExperiment].
diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py
index 7575620d9c..06c852162f 100644
--- a/google/cloud/aiplatform_v1beta1/__init__.py
+++ b/google/cloud/aiplatform_v1beta1/__init__.py
@@ -91,6 +91,7 @@
from .types.dataset import ImportDataConfig
from .types.dataset_service import CreateDatasetOperationMetadata
from .types.dataset_service import CreateDatasetRequest
+from .types.dataset_service import CreateDatasetVersionOperationMetadata
from .types.dataset_service import DataItemView
from .types.dataset_service import DeleteDatasetRequest
from .types.dataset_service import DeleteSavedQueryRequest
@@ -110,6 +111,7 @@
from .types.dataset_service import ListDatasetsResponse
from .types.dataset_service import ListSavedQueriesRequest
from .types.dataset_service import ListSavedQueriesResponse
+from .types.dataset_service import RestoreDatasetVersionOperationMetadata
from .types.dataset_service import SearchDataItemsRequest
from .types.dataset_service import SearchDataItemsResponse
from .types.dataset_service import UpdateDatasetRequest
@@ -331,6 +333,7 @@
from .types.machine_resources import DiskSpec
from .types.machine_resources import MachineSpec
from .types.machine_resources import NfsMount
+from .types.machine_resources import PersistentDiskSpec
from .types.machine_resources import ResourcesConsumed
from .types.manual_batch_tuning_parameters import ManualBatchTuningParameters
from .types.match_service import FindNeighborsRequest
@@ -495,6 +498,8 @@
from .types.pipeline_service import ListTrainingPipelinesRequest
from .types.pipeline_service import ListTrainingPipelinesResponse
from .types.pipeline_state import PipelineState
+from .types.prediction_service import CountTokensRequest
+from .types.prediction_service import CountTokensResponse
from .types.prediction_service import ExplainRequest
from .types.prediction_service import ExplainResponse
from .types.prediction_service import PredictRequest
@@ -701,6 +706,8 @@
"CopyModelOperationMetadata",
"CopyModelRequest",
"CopyModelResponse",
+ "CountTokensRequest",
+ "CountTokensResponse",
"CreateArtifactRequest",
"CreateBatchPredictionJobRequest",
"CreateContextRequest",
@@ -708,6 +715,7 @@
"CreateDataLabelingJobRequest",
"CreateDatasetOperationMetadata",
"CreateDatasetRequest",
+ "CreateDatasetVersionOperationMetadata",
"CreateDeploymentResourcePoolOperationMetadata",
"CreateDeploymentResourcePoolRequest",
"CreateEndpointOperationMetadata",
@@ -1043,6 +1051,7 @@
"NfsMount",
"PauseModelDeploymentMonitoringJobRequest",
"PauseScheduleRequest",
+ "PersistentDiskSpec",
"PersistentResource",
"PersistentResourceServiceClient",
"PipelineFailurePolicy",
@@ -1102,6 +1111,7 @@
"ResourceRuntime",
"ResourceRuntimeSpec",
"ResourcesConsumed",
+ "RestoreDatasetVersionOperationMetadata",
"ResumeModelDeploymentMonitoringJobRequest",
"ResumeScheduleRequest",
"SampleConfig",
diff --git a/google/cloud/aiplatform_v1beta1/gapic_metadata.json b/google/cloud/aiplatform_v1beta1/gapic_metadata.json
index 47b9306b67..530f02ee15 100644
--- a/google/cloud/aiplatform_v1beta1/gapic_metadata.json
+++ b/google/cloud/aiplatform_v1beta1/gapic_metadata.json
@@ -1910,6 +1910,11 @@
"grpc": {
"libraryClient": "PredictionServiceClient",
"rpcs": {
+ "CountTokens": {
+ "methods": [
+ "count_tokens"
+ ]
+ },
"Explain": {
"methods": [
"explain"
@@ -1935,6 +1940,11 @@
"grpc-async": {
"libraryClient": "PredictionServiceAsyncClient",
"rpcs": {
+ "CountTokens": {
+ "methods": [
+ "count_tokens"
+ ]
+ },
"Explain": {
"methods": [
"explain"
diff --git a/google/cloud/aiplatform_v1beta1/gapic_version.py b/google/cloud/aiplatform_v1beta1/gapic_version.py
index dcedf0af71..b26e773716 100644
--- a/google/cloud/aiplatform_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.30.1" # {x-release-please-version}
+__version__ = "1.31.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py
index 7ff2e219de..9e03f37b72 100644
--- a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py
+++ b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py
@@ -208,23 +208,18 @@ def parse_annotated_dataset_path(path: str) -> Dict[str, str]:
@staticmethod
def dataset_path(
project: str,
- location: str,
dataset: str,
) -> str:
"""Returns a fully-qualified dataset string."""
- return "projects/{project}/locations/{location}/datasets/{dataset}".format(
+ return "projects/{project}/datasets/{dataset}".format(
project=project,
- location=location,
dataset=dataset,
)
@staticmethod
def parse_dataset_path(path: str) -> Dict[str, str]:
"""Parses a dataset path into its component segments."""
- m = re.match(
- r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$",
- path,
- )
+ m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path)
return m.groupdict() if m else {}
@staticmethod
@@ -252,18 +247,23 @@ def parse_dataset_path(path: str) -> Dict[str, str]:
@staticmethod
def dataset_path(
project: str,
+ location: str,
dataset: str,
) -> str:
"""Returns a fully-qualified dataset string."""
- return "projects/{project}/datasets/{dataset}".format(
+ return "projects/{project}/locations/{location}/datasets/{dataset}".format(
project=project,
+ location=location,
dataset=dataset,
)
@staticmethod
def parse_dataset_path(path: str) -> Dict[str, str]:
"""Parses a dataset path into its component segments."""
- m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path)
+ m = re.match(
+ r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$",
+ path,
+ )
return m.groupdict() if m else {}
@staticmethod
diff --git a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/async_client.py
index c08d2951cf..ad59d81337 100644
--- a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/async_client.py
+++ b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/async_client.py
@@ -260,7 +260,7 @@ async def create_persistent_resource(
timeout: Union[float, object] = gapic_v1.method.DEFAULT,
metadata: Sequence[Tuple[str, str]] = (),
) -> operation_async.AsyncOperation:
- r"""Uploads a Model artifact into Vertex AI.
+ r"""Creates a PersistentResource.
.. code-block:: python
diff --git a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/client.py b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/client.py
index 63a82afde3..61cb1f217e 100644
--- a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/client.py
+++ b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/client.py
@@ -491,7 +491,7 @@ def create_persistent_resource(
timeout: Union[float, object] = gapic_v1.method.DEFAULT,
metadata: Sequence[Tuple[str, str]] = (),
) -> gac_operation.Operation:
- r"""Uploads a Model artifact into Vertex AI.
+ r"""Creates a PersistentResource.
.. code-block:: python
diff --git a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/transports/grpc.py
index 502d6f2b67..ab8a369deb 100644
--- a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/transports/grpc.py
+++ b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/transports/grpc.py
@@ -261,7 +261,7 @@ def create_persistent_resource(
]:
r"""Return a callable for the create persistent resource method over gRPC.
- Uploads a Model artifact into Vertex AI.
+ Creates a PersistentResource.
Returns:
Callable[[~.CreatePersistentResourceRequest],
diff --git a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/transports/grpc_asyncio.py
index be18981831..765266b821 100644
--- a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/transports/grpc_asyncio.py
+++ b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/transports/grpc_asyncio.py
@@ -266,7 +266,7 @@ def create_persistent_resource(
]:
r"""Return a callable for the create persistent resource method over gRPC.
- Uploads a Model artifact into Vertex AI.
+ Creates a PersistentResource.
Returns:
Callable[[~.CreatePersistentResourceRequest],
diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py
index 79fa9943e4..b185fb8616 100644
--- a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py
+++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py
@@ -16,7 +16,6 @@
from collections import OrderedDict
import functools
import re
-import pkg_resources
from typing import (
Dict,
Mapping,
@@ -802,6 +801,125 @@ async def sample_explain():
# Done; return the response.
return response
+ async def count_tokens(
+ self,
+ request: Optional[Union[prediction_service.CountTokensRequest, dict]] = None,
+ *,
+ endpoint: Optional[str] = None,
+ instances: Optional[MutableSequence[struct_pb2.Value]] = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: Union[float, object] = gapic_v1.method.DEFAULT,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> prediction_service.CountTokensResponse:
+ r"""Perform a token counting.
+
+ .. code-block:: python
+
+ # This snippet has been automatically generated and should be regarded as a
+ # code template only.
+ # It will require modifications to work:
+ # - It may require correct/in-range values for request initialization.
+ # - It may require specifying regional endpoints when creating the service
+ # client as shown in:
+ # https://googleapis.dev/python/google-api-core/latest/client_options.html
+ from google.cloud import aiplatform_v1beta1
+
+ async def sample_count_tokens():
+ # Create a client
+ client = aiplatform_v1beta1.PredictionServiceAsyncClient()
+
+ # Initialize request argument(s)
+ instances = aiplatform_v1beta1.Value()
+ instances.null_value = "NULL_VALUE"
+
+ request = aiplatform_v1beta1.CountTokensRequest(
+ endpoint="endpoint_value",
+ instances=instances,
+ )
+
+ # Make the request
+ response = await client.count_tokens(request=request)
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Optional[Union[google.cloud.aiplatform_v1beta1.types.CountTokensRequest, dict]]):
+ The request object. Request message for
+ [PredictionService.CountTokens][google.cloud.aiplatform.v1beta1.PredictionService.CountTokens].
+ endpoint (:class:`str`):
+ Required. The name of the Endpoint requested to perform
+ token counting. Format:
+ ``projects/{project}/locations/{location}/endpoints/{endpoint}``
+
+ This corresponds to the ``endpoint`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ instances (:class:`MutableSequence[google.protobuf.struct_pb2.Value]`):
+ Required. The instances that are the
+ input to token counting call. Schema is
+ identical to the prediction schema of
+ the underlying model.
+
+ This corresponds to the ``instances`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.aiplatform_v1beta1.types.CountTokensResponse:
+ Response message for
+ [PredictionService.CountTokens][google.cloud.aiplatform.v1beta1.PredictionService.CountTokens].
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([endpoint, instances])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = prediction_service.CountTokensRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if endpoint is not None:
+ request.endpoint = endpoint
+ if instances:
+ request.instances.extend(instances)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.count_tokens,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
async def list_operations(
self,
request: Optional[operations_pb2.ListOperationsRequest] = None,
diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py
index ba928c663c..083d575569 100644
--- a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py
+++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py
@@ -16,7 +16,6 @@
from collections import OrderedDict
import os
import re
-import pkg_resources
from typing import (
Dict,
Mapping,
@@ -1049,6 +1048,125 @@ def sample_explain():
# Done; return the response.
return response
+ def count_tokens(
+ self,
+ request: Optional[Union[prediction_service.CountTokensRequest, dict]] = None,
+ *,
+ endpoint: Optional[str] = None,
+ instances: Optional[MutableSequence[struct_pb2.Value]] = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: Union[float, object] = gapic_v1.method.DEFAULT,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> prediction_service.CountTokensResponse:
+ r"""Perform a token counting.
+
+ .. code-block:: python
+
+ # This snippet has been automatically generated and should be regarded as a
+ # code template only.
+ # It will require modifications to work:
+ # - It may require correct/in-range values for request initialization.
+ # - It may require specifying regional endpoints when creating the service
+ # client as shown in:
+ # https://googleapis.dev/python/google-api-core/latest/client_options.html
+ from google.cloud import aiplatform_v1beta1
+
+ def sample_count_tokens():
+ # Create a client
+ client = aiplatform_v1beta1.PredictionServiceClient()
+
+ # Initialize request argument(s)
+ instances = aiplatform_v1beta1.Value()
+ instances.null_value = "NULL_VALUE"
+
+ request = aiplatform_v1beta1.CountTokensRequest(
+ endpoint="endpoint_value",
+ instances=instances,
+ )
+
+ # Make the request
+ response = client.count_tokens(request=request)
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1beta1.types.CountTokensRequest, dict]):
+ The request object. Request message for
+ [PredictionService.CountTokens][google.cloud.aiplatform.v1beta1.PredictionService.CountTokens].
+ endpoint (str):
+ Required. The name of the Endpoint requested to perform
+ token counting. Format:
+ ``projects/{project}/locations/{location}/endpoints/{endpoint}``
+
+ This corresponds to the ``endpoint`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ instances (MutableSequence[google.protobuf.struct_pb2.Value]):
+ Required. The instances that are the
+ input to token counting call. Schema is
+ identical to the prediction schema of
+ the underlying model.
+
+ This corresponds to the ``instances`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.aiplatform_v1beta1.types.CountTokensResponse:
+ Response message for
+ [PredictionService.CountTokens][google.cloud.aiplatform.v1beta1.PredictionService.CountTokens].
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([endpoint, instances])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ # Minor optimization to avoid making a copy if the user passes
+ # in a prediction_service.CountTokensRequest.
+ # There's no risk of modifying the input as we've already verified
+ # there are no flattened fields.
+ if not isinstance(request, prediction_service.CountTokensRequest):
+ request = prediction_service.CountTokensRequest(request)
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if endpoint is not None:
+ request.endpoint = endpoint
+ if instances is not None:
+ request.instances.extend(instances)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._transport._wrapped_methods[self._transport.count_tokens]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
def __enter__(self) -> "PredictionServiceClient":
return self
diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py
index 44ab49ffb0..2ff643a45f 100644
--- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py
+++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py
@@ -148,6 +148,11 @@ def _prep_wrapped_messages(self, client_info):
default_timeout=5.0,
client_info=client_info,
),
+ self.count_tokens: gapic_v1.method.wrap_method(
+ self.count_tokens,
+ default_timeout=None,
+ client_info=client_info,
+ ),
}
def close(self):
@@ -204,6 +209,18 @@ def explain(
]:
raise NotImplementedError()
+ @property
+ def count_tokens(
+ self,
+ ) -> Callable[
+ [prediction_service.CountTokensRequest],
+ Union[
+ prediction_service.CountTokensResponse,
+ Awaitable[prediction_service.CountTokensResponse],
+ ],
+ ]:
+ raise NotImplementedError()
+
@property
def list_operations(
self,
diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py
index a6017621e7..547cdd269e 100644
--- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py
+++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py
@@ -366,6 +366,34 @@ def explain(
)
return self._stubs["explain"]
+ @property
+ def count_tokens(
+ self,
+ ) -> Callable[
+ [prediction_service.CountTokensRequest], prediction_service.CountTokensResponse
+ ]:
+ r"""Return a callable for the count tokens method over gRPC.
+
+ Perform a token counting.
+
+ Returns:
+ Callable[[~.CountTokensRequest],
+ ~.CountTokensResponse]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "count_tokens" not in self._stubs:
+ self._stubs["count_tokens"] = self.grpc_channel.unary_unary(
+ "/google.cloud.aiplatform.v1beta1.PredictionService/CountTokens",
+ request_serializer=prediction_service.CountTokensRequest.serialize,
+ response_deserializer=prediction_service.CountTokensResponse.deserialize,
+ )
+ return self._stubs["count_tokens"]
+
def close(self):
self.grpc_channel.close()
diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py
index 841c3c7898..bf44a1d826 100644
--- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py
+++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py
@@ -373,6 +373,35 @@ def explain(
)
return self._stubs["explain"]
+ @property
+ def count_tokens(
+ self,
+ ) -> Callable[
+ [prediction_service.CountTokensRequest],
+ Awaitable[prediction_service.CountTokensResponse],
+ ]:
+ r"""Return a callable for the count tokens method over gRPC.
+
+ Perform a token counting.
+
+ Returns:
+ Callable[[~.CountTokensRequest],
+ Awaitable[~.CountTokensResponse]]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "count_tokens" not in self._stubs:
+ self._stubs["count_tokens"] = self.grpc_channel.unary_unary(
+ "/google.cloud.aiplatform.v1beta1.PredictionService/CountTokens",
+ request_serializer=prediction_service.CountTokensRequest.serialize,
+ response_deserializer=prediction_service.CountTokensResponse.deserialize,
+ )
+ return self._stubs["count_tokens"]
+
def close(self):
return self.grpc_channel.close()
diff --git a/google/cloud/aiplatform_v1beta1/services/schedule_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/schedule_service/async_client.py
index bf9658807b..e2c7334df0 100644
--- a/google/cloud/aiplatform_v1beta1/services/schedule_service/async_client.py
+++ b/google/cloud/aiplatform_v1beta1/services/schedule_service/async_client.py
@@ -976,6 +976,7 @@ async def sample_update_schedule():
the server. The following restrictions will be applied:
- The scheduled request type cannot be changed.
+ - The non-empty fields cannot be unset.
- The output_only fields will be ignored if specified.
This corresponds to the ``schedule`` field
diff --git a/google/cloud/aiplatform_v1beta1/services/schedule_service/client.py b/google/cloud/aiplatform_v1beta1/services/schedule_service/client.py
index a1267630c5..78fdc53aaa 100644
--- a/google/cloud/aiplatform_v1beta1/services/schedule_service/client.py
+++ b/google/cloud/aiplatform_v1beta1/services/schedule_service/client.py
@@ -1324,6 +1324,7 @@ def sample_update_schedule():
the server. The following restrictions will be applied:
- The scheduled request type cannot be changed.
+ - The non-empty fields cannot be unset.
- The output_only fields will be ignored if specified.
This corresponds to the ``schedule`` field
diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py
index cd77aed5bc..6340923f22 100644
--- a/google/cloud/aiplatform_v1beta1/types/__init__.py
+++ b/google/cloud/aiplatform_v1beta1/types/__init__.py
@@ -57,6 +57,7 @@
from .dataset_service import (
CreateDatasetOperationMetadata,
CreateDatasetRequest,
+ CreateDatasetVersionOperationMetadata,
DataItemView,
DeleteDatasetRequest,
DeleteSavedQueryRequest,
@@ -76,6 +77,7 @@
ListDatasetsResponse,
ListSavedQueriesRequest,
ListSavedQueriesResponse,
+ RestoreDatasetVersionOperationMetadata,
SearchDataItemsRequest,
SearchDataItemsResponse,
UpdateDatasetRequest,
@@ -352,6 +354,7 @@
DiskSpec,
MachineSpec,
NfsMount,
+ PersistentDiskSpec,
ResourcesConsumed,
)
from .manual_batch_tuning_parameters import (
@@ -548,6 +551,8 @@
ListTrainingPipelinesResponse,
)
from .prediction_service import (
+ CountTokensRequest,
+ CountTokensResponse,
ExplainRequest,
ExplainResponse,
PredictRequest,
@@ -740,6 +745,7 @@
"ImportDataConfig",
"CreateDatasetOperationMetadata",
"CreateDatasetRequest",
+ "CreateDatasetVersionOperationMetadata",
"DataItemView",
"DeleteDatasetRequest",
"DeleteSavedQueryRequest",
@@ -759,6 +765,7 @@
"ListDatasetsResponse",
"ListSavedQueriesRequest",
"ListSavedQueriesResponse",
+ "RestoreDatasetVersionOperationMetadata",
"SearchDataItemsRequest",
"SearchDataItemsResponse",
"UpdateDatasetRequest",
@@ -976,6 +983,7 @@
"DiskSpec",
"MachineSpec",
"NfsMount",
+ "PersistentDiskSpec",
"ResourcesConsumed",
"ManualBatchTuningParameters",
"FindNeighborsRequest",
@@ -1132,6 +1140,8 @@
"ListTrainingPipelinesRequest",
"ListTrainingPipelinesResponse",
"PipelineState",
+ "CountTokensRequest",
+ "CountTokensResponse",
"ExplainRequest",
"ExplainResponse",
"PredictRequest",
diff --git a/google/cloud/aiplatform_v1beta1/types/accelerator_type.py b/google/cloud/aiplatform_v1beta1/types/accelerator_type.py
index 9abc3bc1da..f8ba96e776 100644
--- a/google/cloud/aiplatform_v1beta1/types/accelerator_type.py
+++ b/google/cloud/aiplatform_v1beta1/types/accelerator_type.py
@@ -51,12 +51,16 @@ class AcceleratorType(proto.Enum):
Nvidia A100 80GB GPU.
NVIDIA_L4 (11):
Nvidia L4 GPU.
+ NVIDIA_H100_80GB (13):
+ Nvidia H100 80Gb GPU.
TPU_V2 (6):
TPU v2.
TPU_V3 (7):
TPU v3.
TPU_V4_POD (10):
TPU v4.
+ TPU_V5_LITEPOD (12):
+ TPU v5.
"""
ACCELERATOR_TYPE_UNSPECIFIED = 0
NVIDIA_TESLA_K80 = 1
@@ -67,9 +71,11 @@ class AcceleratorType(proto.Enum):
NVIDIA_TESLA_A100 = 8
NVIDIA_A100_80GB = 9
NVIDIA_L4 = 11
+ NVIDIA_H100_80GB = 13
TPU_V2 = 6
TPU_V3 = 7
TPU_V4_POD = 10
+ TPU_V5_LITEPOD = 12
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform_v1beta1/types/context.py b/google/cloud/aiplatform_v1beta1/types/context.py
index a73cfa7bf3..6e9f5f611a 100644
--- a/google/cloud/aiplatform_v1beta1/types/context.py
+++ b/google/cloud/aiplatform_v1beta1/types/context.py
@@ -36,8 +36,7 @@ class Context(proto.Message):
Attributes:
name (str):
- Output only. The resource name of the
- Context.
+ Immutable. The resource name of the Context.
display_name (str):
User provided display name of the Context.
May be up to 128 Unicode characters.
diff --git a/google/cloud/aiplatform_v1beta1/types/custom_job.py b/google/cloud/aiplatform_v1beta1/types/custom_job.py
index 1e9a95c51c..4662c176a1 100644
--- a/google/cloud/aiplatform_v1beta1/types/custom_job.py
+++ b/google/cloud/aiplatform_v1beta1/types/custom_job.py
@@ -507,6 +507,10 @@ class Scheduling(proto.Message):
gets restarted. This feature can be used by
distributed training jobs that are not resilient
to workers leaving and joining a job.
+ disable_retries (bool):
+ Optional. Indicates if the job should retry for internal
+ errors after the job starts running. If true, overrides
+ ``Scheduling.restart_job_on_worker_restart`` to false.
"""
timeout: duration_pb2.Duration = proto.Field(
@@ -518,6 +522,10 @@ class Scheduling(proto.Message):
proto.BOOL,
number=3,
)
+ disable_retries: bool = proto.Field(
+ proto.BOOL,
+ number=5,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform_v1beta1/types/dataset_service.py b/google/cloud/aiplatform_v1beta1/types/dataset_service.py
index 8e20c90529..5d88231aef 100644
--- a/google/cloud/aiplatform_v1beta1/types/dataset_service.py
+++ b/google/cloud/aiplatform_v1beta1/types/dataset_service.py
@@ -43,6 +43,8 @@
"ExportDataRequest",
"ExportDataResponse",
"ExportDataOperationMetadata",
+ "CreateDatasetVersionOperationMetadata",
+ "RestoreDatasetVersionOperationMetadata",
"ListDataItemsRequest",
"ListDataItemsResponse",
"SearchDataItemsRequest",
@@ -374,6 +376,38 @@ class ExportDataOperationMetadata(proto.Message):
)
+class CreateDatasetVersionOperationMetadata(proto.Message):
+ r"""Runtime operation information for
+ [DatasetService.CreateDatasetVersion][google.cloud.aiplatform.v1beta1.DatasetService.CreateDatasetVersion].
+
+ Attributes:
+ generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata):
+ The common part of the operation metadata.
+ """
+
+ generic_metadata: operation.GenericOperationMetadata = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ message=operation.GenericOperationMetadata,
+ )
+
+
+class RestoreDatasetVersionOperationMetadata(proto.Message):
+ r"""Runtime operation information for
+ [DatasetService.RestoreDatasetVersion][google.cloud.aiplatform.v1beta1.DatasetService.RestoreDatasetVersion].
+
+ Attributes:
+ generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata):
+ The common part of the operation metadata.
+ """
+
+ generic_metadata: operation.GenericOperationMetadata = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ message=operation.GenericOperationMetadata,
+ )
+
+
class ListDataItemsRequest(proto.Message):
r"""Request message for
[DatasetService.ListDataItems][google.cloud.aiplatform.v1beta1.DatasetService.ListDataItems].
diff --git a/google/cloud/aiplatform_v1beta1/types/deployment_resource_pool.py b/google/cloud/aiplatform_v1beta1/types/deployment_resource_pool.py
index 1d4201478b..19fa2f108c 100644
--- a/google/cloud/aiplatform_v1beta1/types/deployment_resource_pool.py
+++ b/google/cloud/aiplatform_v1beta1/types/deployment_resource_pool.py
@@ -38,8 +38,8 @@ class DeploymentResourcePool(proto.Message):
Attributes:
name (str):
- Output only. The resource name of the
- DeploymentResourcePool. Format:
+ Immutable. The resource name of the DeploymentResourcePool.
+ Format:
``projects/{project}/locations/{location}/deploymentResourcePools/{deployment_resource_pool}``
dedicated_resources (google.cloud.aiplatform_v1beta1.types.DedicatedResources):
Required. The underlying DedicatedResources
diff --git a/google/cloud/aiplatform_v1beta1/types/index.py b/google/cloud/aiplatform_v1beta1/types/index.py
index e4bf685cbc..5cbc636054 100644
--- a/google/cloud/aiplatform_v1beta1/types/index.py
+++ b/google/cloud/aiplatform_v1beta1/types/index.py
@@ -197,6 +197,7 @@ class IndexDatapoint(proto.Message):
used to perform "restricted searches" where
boolean rule are used to filter the subset of
the database eligible for matching. See:
+
https://cloud.google.com/vertex-ai/docs/matching-engine/filtering
crowding_tag (google.cloud.aiplatform_v1beta1.types.IndexDatapoint.CrowdingTag):
Optional. CrowdingTag of the datapoint, the
diff --git a/google/cloud/aiplatform_v1beta1/types/machine_resources.py b/google/cloud/aiplatform_v1beta1/types/machine_resources.py
index 04f9d77452..d5d2bf7de1 100644
--- a/google/cloud/aiplatform_v1beta1/types/machine_resources.py
+++ b/google/cloud/aiplatform_v1beta1/types/machine_resources.py
@@ -33,6 +33,7 @@
"BatchDedicatedResources",
"ResourcesConsumed",
"DiskSpec",
+ "PersistentDiskSpec",
"NfsMount",
"AutoscalingMetricSpec",
},
@@ -292,6 +293,33 @@ class DiskSpec(proto.Message):
)
+class PersistentDiskSpec(proto.Message):
+ r"""Represents the spec of [persistent
+ disk][https://cloud.google.com/compute/docs/disks/persistent-disks]
+ options.
+
+ Attributes:
+ disk_type (str):
+ Type of the disk (default is "pd-standard").
+ Valid values: "pd-ssd" (Persistent Disk Solid
+ State Drive) "pd-standard" (Persistent Disk Hard
+ Disk Drive) "pd-balanced" (Balanced Persistent
+ Disk)
+ "pd-extreme" (Extreme Persistent Disk)
+ disk_size_gb (int):
+ Size in GB of the disk (default is 100GB).
+ """
+
+ disk_type: str = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ disk_size_gb: int = proto.Field(
+ proto.INT64,
+ number=2,
+ )
+
+
class NfsMount(proto.Message):
r"""Represents a mount configuration for Network File System
(NFS) to mount.
diff --git a/google/cloud/aiplatform_v1beta1/types/persistent_resource.py b/google/cloud/aiplatform_v1beta1/types/persistent_resource.py
index 7f7ba49e2d..4425c86407 100644
--- a/google/cloud/aiplatform_v1beta1/types/persistent_resource.py
+++ b/google/cloud/aiplatform_v1beta1/types/persistent_resource.py
@@ -223,7 +223,7 @@ class ResourcePool(proto.Message):
Attributes:
id (str):
- Optional. The unique ID in a
+ Immutable. The unique ID in a
PersistentResource to refer the this resource
pool. User can specify it if need to use it,
otherwise we will generate it automatically.
@@ -238,10 +238,6 @@ class ResourcePool(proto.Message):
disk_spec (google.cloud.aiplatform_v1beta1.types.DiskSpec):
Optional. Disk spec for the machine in this
node pool.
- idle_replica_count (int):
- Output only. The number of machines currently not in use by
- training jobs for this resource pool. Deprecated. Use
- ``used_replica_count`` instead.
used_replica_count (int):
Output only. The number of machines currently in use by
training jobs for this resource pool. Will replace
@@ -301,10 +297,6 @@ class AutoscalingSpec(proto.Message):
number=4,
message=machine_resources.DiskSpec,
)
- idle_replica_count: int = proto.Field(
- proto.INT64,
- number=5,
- )
used_replica_count: int = proto.Field(
proto.INT64,
number=6,
@@ -328,7 +320,7 @@ class ResourceRuntimeSpec(proto.Message):
Optional. Configure the use of workload
identity on the PersistentResource
ray_spec (google.cloud.aiplatform_v1beta1.types.RaySpec):
- Ray cluster configuration.
+ Optional. Ray cluster configuration.
Required when creating a dedicated RayCluster on
the PersistentResource.
"""
diff --git a/google/cloud/aiplatform_v1beta1/types/pipeline_job.py b/google/cloud/aiplatform_v1beta1/types/pipeline_job.py
index b8f82af8dc..67aa7ed124 100644
--- a/google/cloud/aiplatform_v1beta1/types/pipeline_job.py
+++ b/google/cloud/aiplatform_v1beta1/types/pipeline_job.py
@@ -140,6 +140,10 @@ class PipelineJob(proto.Message):
if
[PipelineJob.template_uri][google.cloud.aiplatform.v1beta1.PipelineJob.template_uri]
is from supported template registry.
+ schedule_name (str):
+ Output only. The schedule resource name.
+ Only returned if the Pipeline is created by
+ Schedule API.
"""
class RuntimeConfig(proto.Message):
@@ -325,6 +329,10 @@ class InputArtifact(proto.Message):
number=20,
message="PipelineTemplateMetadata",
)
+ schedule_name: str = proto.Field(
+ proto.STRING,
+ number=22,
+ )
class PipelineTemplateMetadata(proto.Message):
diff --git a/google/cloud/aiplatform_v1beta1/types/prediction_service.py b/google/cloud/aiplatform_v1beta1/types/prediction_service.py
index ee907296da..fd6b807e14 100644
--- a/google/cloud/aiplatform_v1beta1/types/prediction_service.py
+++ b/google/cloud/aiplatform_v1beta1/types/prediction_service.py
@@ -35,6 +35,8 @@
"StreamingPredictResponse",
"ExplainRequest",
"ExplainResponse",
+ "CountTokensRequest",
+ "CountTokensResponse",
},
)
@@ -356,4 +358,53 @@ class ExplainResponse(proto.Message):
)
+class CountTokensRequest(proto.Message):
+ r"""Request message for
+ [PredictionService.CountTokens][google.cloud.aiplatform.v1beta1.PredictionService.CountTokens].
+
+ Attributes:
+ endpoint (str):
+ Required. The name of the Endpoint requested to perform
+ token counting. Format:
+ ``projects/{project}/locations/{location}/endpoints/{endpoint}``
+ instances (MutableSequence[google.protobuf.struct_pb2.Value]):
+ Required. The instances that are the input to
+ token counting call. Schema is identical to the
+ prediction schema of the underlying model.
+ """
+
+ endpoint: str = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ instances: MutableSequence[struct_pb2.Value] = proto.RepeatedField(
+ proto.MESSAGE,
+ number=2,
+ message=struct_pb2.Value,
+ )
+
+
+class CountTokensResponse(proto.Message):
+ r"""Response message for
+ [PredictionService.CountTokens][google.cloud.aiplatform.v1beta1.PredictionService.CountTokens].
+
+ Attributes:
+ total_tokens (int):
+ The total number of tokens counted across all
+ instances from the request.
+ total_billable_characters (int):
+ The total number of billable characters
+ counted across all instances from the request.
+ """
+
+ total_tokens: int = proto.Field(
+ proto.INT32,
+ number=1,
+ )
+ total_billable_characters: int = proto.Field(
+ proto.INT32,
+ number=2,
+ )
+
+
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform_v1beta1/types/publisher_model.py b/google/cloud/aiplatform_v1beta1/types/publisher_model.py
index 66fc8f2176..a7d50bd8b9 100644
--- a/google/cloud/aiplatform_v1beta1/types/publisher_model.py
+++ b/google/cloud/aiplatform_v1beta1/types/publisher_model.py
@@ -47,6 +47,11 @@ class PublisherModel(proto.Message):
open_source_category (google.cloud.aiplatform_v1beta1.types.PublisherModel.OpenSourceCategory):
Required. Indicates the open source category
of the publisher model.
+ parent (google.cloud.aiplatform_v1beta1.types.PublisherModel.Parent):
+ Optional. The parent that this model was
+ customized from. E.g., Vision API, Natural
+ Language API, LaMDA, T5, etc. Foundation models
+ don't have parents.
supported_actions (google.cloud.aiplatform_v1beta1.types.PublisherModel.CallToAction):
Optional. Supported call-to-action options.
frameworks (MutableSequence[str]):
@@ -157,6 +162,29 @@ class ResourceReference(proto.Message):
oneof="reference",
)
+ class Parent(proto.Message):
+ r"""The information about the parent of a model.
+
+ Attributes:
+ display_name (str):
+ Required. The display name of the parent.
+ E.g., LaMDA, T5, Vision API, Natural Language
+ API.
+ reference (google.cloud.aiplatform_v1beta1.types.PublisherModel.ResourceReference):
+ Optional. The Google Cloud resource name or
+ the URI reference.
+ """
+
+ display_name: str = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ reference: "PublisherModel.ResourceReference" = proto.Field(
+ proto.MESSAGE,
+ number=2,
+ message="PublisherModel.ResourceReference",
+ )
+
class Documentation(proto.Message):
r"""A named piece of documentation.
@@ -206,6 +234,9 @@ class CallToAction(proto.Message):
Optional. Open in Generation AI Studio.
request_access (google.cloud.aiplatform_v1beta1.types.PublisherModel.CallToAction.RegionalResourceReferences):
Optional. Request for access.
+ open_evaluation_pipeline (google.cloud.aiplatform_v1beta1.types.PublisherModel.CallToAction.RegionalResourceReferences):
+ Optional. Open evaluation pipeline of the
+ PublisherModel.
"""
class RegionalResourceReferences(proto.Message):
@@ -396,6 +427,11 @@ class Deploy(proto.Message):
message="PublisherModel.CallToAction.RegionalResourceReferences",
)
)
+ open_evaluation_pipeline: "PublisherModel.CallToAction.RegionalResourceReferences" = proto.Field(
+ proto.MESSAGE,
+ number=11,
+ message="PublisherModel.CallToAction.RegionalResourceReferences",
+ )
name: str = proto.Field(
proto.STRING,
@@ -410,6 +446,11 @@ class Deploy(proto.Message):
number=7,
enum=OpenSourceCategory,
)
+ parent: Parent = proto.Field(
+ proto.MESSAGE,
+ number=14,
+ message=Parent,
+ )
supported_actions: CallToAction = proto.Field(
proto.MESSAGE,
number=19,
diff --git a/google/cloud/aiplatform_v1beta1/types/schedule.py b/google/cloud/aiplatform_v1beta1/types/schedule.py
index d022177bb3..b9c1026050 100644
--- a/google/cloud/aiplatform_v1beta1/types/schedule.py
+++ b/google/cloud/aiplatform_v1beta1/types/schedule.py
@@ -58,8 +58,7 @@ class Schedule(proto.Message):
This field is a member of `oneof`_ ``request``.
name (str):
- Output only. The resource name of the
- Schedule.
+ Immutable. The resource name of the Schedule.
display_name (str):
Required. User provided name of the Schedule.
The name can be up to 128 characters long and
diff --git a/google/cloud/aiplatform_v1beta1/types/schedule_service.py b/google/cloud/aiplatform_v1beta1/types/schedule_service.py
index e77b11fd75..0ef393a3db 100644
--- a/google/cloud/aiplatform_v1beta1/types/schedule_service.py
+++ b/google/cloud/aiplatform_v1beta1/types/schedule_service.py
@@ -274,6 +274,7 @@ class UpdateScheduleRequest(proto.Message):
server. The following restrictions will be applied:
- The scheduled request type cannot be changed.
+ - The non-empty fields cannot be unset.
- The output_only fields will be ignored if specified.
update_mask (google.protobuf.field_mask_pb2.FieldMask):
Required. The update mask applies to the resource. See
diff --git a/noxfile.py b/noxfile.py
index 87ec893504..bc9dbabd5e 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -32,7 +32,7 @@
DEFAULT_PYTHON_VERSION = "3.8"
-UNIT_TEST_PYTHON_VERSIONS = ["3.7", "3.8", "3.9", "3.10"]
+UNIT_TEST_PYTHON_VERSIONS = ["3.7", "3.8", "3.9", "3.10", "3.11"]
UNIT_TEST_STANDARD_DEPENDENCIES = [
"mock",
"asyncmock",
diff --git a/owlbot.py b/owlbot.py
index 5f173be611..263dc6ee39 100644
--- a/owlbot.py
+++ b/owlbot.py
@@ -88,14 +88,14 @@
# only run post processor when there are changes to the generated code
if has_generator_updates:
-# ----------------------------------------------------------------------------
-# Add templated files
-# ----------------------------------------------------------------------------
+ # ----------------------------------------------------------------------------
+ # Add templated files
+ # ----------------------------------------------------------------------------
templated_files = common.py_library(
cov_level=98,
system_test_python_versions=["3.8"],
- unit_test_python_versions=["3.7", "3.8", "3.9", "3.10"],
+ unit_test_python_versions=["3.7", "3.8", "3.9", "3.10", "3.11"],
unit_test_extras=["testing"],
system_test_extras=["testing"],
microgenerator=True,
@@ -130,13 +130,13 @@
s.replace(
".kokoro/samples/python3.*/common.cfg",
"""env_vars: \{
- key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
- value: "python-docs-samples-tests-.*?"
-\}""",
- """env_vars: {
- key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
- value: "ucaip-sample-tests"
-}""",
+ key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
+ value: "python-docs-samples-tests-.*?"
+ \}""",
+ """env_vars: {
+ key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
+ value: "ucaip-sample-tests"
+ }""",
)
s.replace(
diff --git a/samples/generated_samples/aiplatform_v1_generated_tensorboard_service_read_tensorboard_size_async.py b/samples/generated_samples/aiplatform_v1_generated_tensorboard_service_read_tensorboard_size_async.py
new file mode 100644
index 0000000000..7b8b2f51e0
--- /dev/null
+++ b/samples/generated_samples/aiplatform_v1_generated_tensorboard_service_read_tensorboard_size_async.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for ReadTensorboardSize
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
+
+# To install the latest published package dependency, execute the following:
+# python3 -m pip install google-cloud-aiplatform
+
+
+# [START aiplatform_v1_generated_TensorboardService_ReadTensorboardSize_async]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service
+# client as shown in:
+# https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.cloud import aiplatform_v1
+
+
+async def sample_read_tensorboard_size():
+ # Create a client
+ client = aiplatform_v1.TensorboardServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.ReadTensorboardSizeRequest(
+ tensorboard="tensorboard_value",
+ )
+
+ # Make the request
+ response = await client.read_tensorboard_size(request=request)
+
+ # Handle the response
+ print(response)
+
+# [END aiplatform_v1_generated_TensorboardService_ReadTensorboardSize_async]
diff --git a/samples/generated_samples/aiplatform_v1_generated_tensorboard_service_read_tensorboard_size_sync.py b/samples/generated_samples/aiplatform_v1_generated_tensorboard_service_read_tensorboard_size_sync.py
new file mode 100644
index 0000000000..d9f6e5b734
--- /dev/null
+++ b/samples/generated_samples/aiplatform_v1_generated_tensorboard_service_read_tensorboard_size_sync.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for ReadTensorboardSize
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
+
+# To install the latest published package dependency, execute the following:
+# python3 -m pip install google-cloud-aiplatform
+
+
+# [START aiplatform_v1_generated_TensorboardService_ReadTensorboardSize_sync]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service
+# client as shown in:
+# https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.cloud import aiplatform_v1
+
+
+def sample_read_tensorboard_size():
+ # Create a client
+ client = aiplatform_v1.TensorboardServiceClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.ReadTensorboardSizeRequest(
+ tensorboard="tensorboard_value",
+ )
+
+ # Make the request
+ response = client.read_tensorboard_size(request=request)
+
+ # Handle the response
+ print(response)
+
+# [END aiplatform_v1_generated_TensorboardService_ReadTensorboardSize_sync]
diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_prediction_service_count_tokens_async.py b/samples/generated_samples/aiplatform_v1beta1_generated_prediction_service_count_tokens_async.py
new file mode 100644
index 0000000000..9b28517f87
--- /dev/null
+++ b/samples/generated_samples/aiplatform_v1beta1_generated_prediction_service_count_tokens_async.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for CountTokens
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
+
+# To install the latest published package dependency, execute the following:
+# python3 -m pip install google-cloud-aiplatform
+
+
+# [START aiplatform_v1beta1_generated_PredictionService_CountTokens_async]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service
+# client as shown in:
+# https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.cloud import aiplatform_v1beta1
+
+
+async def sample_count_tokens():
+ # Create a client
+ client = aiplatform_v1beta1.PredictionServiceAsyncClient()
+
+ # Initialize request argument(s)
+ instances = aiplatform_v1beta1.Value()
+ instances.null_value = "NULL_VALUE"
+
+ request = aiplatform_v1beta1.CountTokensRequest(
+ endpoint="endpoint_value",
+ instances=instances,
+ )
+
+ # Make the request
+ response = await client.count_tokens(request=request)
+
+ # Handle the response
+ print(response)
+
+# [END aiplatform_v1beta1_generated_PredictionService_CountTokens_async]
diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_prediction_service_count_tokens_sync.py b/samples/generated_samples/aiplatform_v1beta1_generated_prediction_service_count_tokens_sync.py
new file mode 100644
index 0000000000..4a08a5dce8
--- /dev/null
+++ b/samples/generated_samples/aiplatform_v1beta1_generated_prediction_service_count_tokens_sync.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for CountTokens
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
+
+# To install the latest published package dependency, execute the following:
+# python3 -m pip install google-cloud-aiplatform
+
+
+# [START aiplatform_v1beta1_generated_PredictionService_CountTokens_sync]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service
+# client as shown in:
+# https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.cloud import aiplatform_v1beta1
+
+
+def sample_count_tokens():
+ # Create a client
+ client = aiplatform_v1beta1.PredictionServiceClient()
+
+ # Initialize request argument(s)
+ instances = aiplatform_v1beta1.Value()
+ instances.null_value = "NULL_VALUE"
+
+ request = aiplatform_v1beta1.CountTokensRequest(
+ endpoint="endpoint_value",
+ instances=instances,
+ )
+
+ # Make the request
+ response = client.count_tokens(request=request)
+
+ # Handle the response
+ print(response)
+
+# [END aiplatform_v1beta1_generated_PredictionService_CountTokens_sync]
diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json
index 01c34dc7a3..3d400f3251 100644
--- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json
+++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json
@@ -8,7 +8,7 @@
],
"language": "PYTHON",
"name": "google-cloud-aiplatform",
- "version": "1.30.1"
+ "version": "1.31.0"
},
"snippets": [
{
@@ -32342,6 +32342,167 @@
],
"title": "aiplatform_v1_generated_tensorboard_service_read_tensorboard_blob_data_sync.py"
},
+ {
+ "canonical": true,
+ "clientMethod": {
+ "async": true,
+ "client": {
+ "fullName": "google.cloud.aiplatform_v1.TensorboardServiceAsyncClient",
+ "shortName": "TensorboardServiceAsyncClient"
+ },
+ "fullName": "google.cloud.aiplatform_v1.TensorboardServiceAsyncClient.read_tensorboard_size",
+ "method": {
+ "fullName": "google.cloud.aiplatform.v1.TensorboardService.ReadTensorboardSize",
+ "service": {
+ "fullName": "google.cloud.aiplatform.v1.TensorboardService",
+ "shortName": "TensorboardService"
+ },
+ "shortName": "ReadTensorboardSize"
+ },
+ "parameters": [
+ {
+ "name": "request",
+ "type": "google.cloud.aiplatform_v1.types.ReadTensorboardSizeRequest"
+ },
+ {
+ "name": "tensorboard",
+ "type": "str"
+ },
+ {
+ "name": "retry",
+ "type": "google.api_core.retry.Retry"
+ },
+ {
+ "name": "timeout",
+ "type": "float"
+ },
+ {
+ "name": "metadata",
+ "type": "Sequence[Tuple[str, str]"
+ }
+ ],
+ "resultType": "google.cloud.aiplatform_v1.types.ReadTensorboardSizeResponse",
+ "shortName": "read_tensorboard_size"
+ },
+ "description": "Sample for ReadTensorboardSize",
+ "file": "aiplatform_v1_generated_tensorboard_service_read_tensorboard_size_async.py",
+ "language": "PYTHON",
+ "origin": "API_DEFINITION",
+ "regionTag": "aiplatform_v1_generated_TensorboardService_ReadTensorboardSize_async",
+ "segments": [
+ {
+ "end": 51,
+ "start": 27,
+ "type": "FULL"
+ },
+ {
+ "end": 51,
+ "start": 27,
+ "type": "SHORT"
+ },
+ {
+ "end": 40,
+ "start": 38,
+ "type": "CLIENT_INITIALIZATION"
+ },
+ {
+ "end": 45,
+ "start": 41,
+ "type": "REQUEST_INITIALIZATION"
+ },
+ {
+ "end": 48,
+ "start": 46,
+ "type": "REQUEST_EXECUTION"
+ },
+ {
+ "end": 52,
+ "start": 49,
+ "type": "RESPONSE_HANDLING"
+ }
+ ],
+ "title": "aiplatform_v1_generated_tensorboard_service_read_tensorboard_size_async.py"
+ },
+ {
+ "canonical": true,
+ "clientMethod": {
+ "client": {
+ "fullName": "google.cloud.aiplatform_v1.TensorboardServiceClient",
+ "shortName": "TensorboardServiceClient"
+ },
+ "fullName": "google.cloud.aiplatform_v1.TensorboardServiceClient.read_tensorboard_size",
+ "method": {
+ "fullName": "google.cloud.aiplatform.v1.TensorboardService.ReadTensorboardSize",
+ "service": {
+ "fullName": "google.cloud.aiplatform.v1.TensorboardService",
+ "shortName": "TensorboardService"
+ },
+ "shortName": "ReadTensorboardSize"
+ },
+ "parameters": [
+ {
+ "name": "request",
+ "type": "google.cloud.aiplatform_v1.types.ReadTensorboardSizeRequest"
+ },
+ {
+ "name": "tensorboard",
+ "type": "str"
+ },
+ {
+ "name": "retry",
+ "type": "google.api_core.retry.Retry"
+ },
+ {
+ "name": "timeout",
+ "type": "float"
+ },
+ {
+ "name": "metadata",
+ "type": "Sequence[Tuple[str, str]"
+ }
+ ],
+ "resultType": "google.cloud.aiplatform_v1.types.ReadTensorboardSizeResponse",
+ "shortName": "read_tensorboard_size"
+ },
+ "description": "Sample for ReadTensorboardSize",
+ "file": "aiplatform_v1_generated_tensorboard_service_read_tensorboard_size_sync.py",
+ "language": "PYTHON",
+ "origin": "API_DEFINITION",
+ "regionTag": "aiplatform_v1_generated_TensorboardService_ReadTensorboardSize_sync",
+ "segments": [
+ {
+ "end": 51,
+ "start": 27,
+ "type": "FULL"
+ },
+ {
+ "end": 51,
+ "start": 27,
+ "type": "SHORT"
+ },
+ {
+ "end": 40,
+ "start": 38,
+ "type": "CLIENT_INITIALIZATION"
+ },
+ {
+ "end": 45,
+ "start": 41,
+ "type": "REQUEST_INITIALIZATION"
+ },
+ {
+ "end": 48,
+ "start": 46,
+ "type": "REQUEST_EXECUTION"
+ },
+ {
+ "end": 52,
+ "start": 49,
+ "type": "RESPONSE_HANDLING"
+ }
+ ],
+ "title": "aiplatform_v1_generated_tensorboard_service_read_tensorboard_size_sync.py"
+ },
{
"canonical": true,
"clientMethod": {
diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json
index 003c898b31..4eb8b76b32 100644
--- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json
+++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json
@@ -8,7 +8,7 @@
],
"language": "PYTHON",
"name": "google-cloud-aiplatform",
- "version": "1.30.1"
+ "version": "1.31.0"
},
"snippets": [
{
@@ -27734,6 +27734,175 @@
],
"title": "aiplatform_v1beta1_generated_pipeline_service_list_training_pipelines_sync.py"
},
+ {
+ "canonical": true,
+ "clientMethod": {
+ "async": true,
+ "client": {
+ "fullName": "google.cloud.aiplatform_v1beta1.PredictionServiceAsyncClient",
+ "shortName": "PredictionServiceAsyncClient"
+ },
+ "fullName": "google.cloud.aiplatform_v1beta1.PredictionServiceAsyncClient.count_tokens",
+ "method": {
+ "fullName": "google.cloud.aiplatform.v1beta1.PredictionService.CountTokens",
+ "service": {
+ "fullName": "google.cloud.aiplatform.v1beta1.PredictionService",
+ "shortName": "PredictionService"
+ },
+ "shortName": "CountTokens"
+ },
+ "parameters": [
+ {
+ "name": "request",
+ "type": "google.cloud.aiplatform_v1beta1.types.CountTokensRequest"
+ },
+ {
+ "name": "endpoint",
+ "type": "str"
+ },
+ {
+ "name": "instances",
+ "type": "MutableSequence[google.protobuf.struct_pb2.Value]"
+ },
+ {
+ "name": "retry",
+ "type": "google.api_core.retry.Retry"
+ },
+ {
+ "name": "timeout",
+ "type": "float"
+ },
+ {
+ "name": "metadata",
+ "type": "Sequence[Tuple[str, str]"
+ }
+ ],
+ "resultType": "google.cloud.aiplatform_v1beta1.types.CountTokensResponse",
+ "shortName": "count_tokens"
+ },
+ "description": "Sample for CountTokens",
+ "file": "aiplatform_v1beta1_generated_prediction_service_count_tokens_async.py",
+ "language": "PYTHON",
+ "origin": "API_DEFINITION",
+ "regionTag": "aiplatform_v1beta1_generated_PredictionService_CountTokens_async",
+ "segments": [
+ {
+ "end": 55,
+ "start": 27,
+ "type": "FULL"
+ },
+ {
+ "end": 55,
+ "start": 27,
+ "type": "SHORT"
+ },
+ {
+ "end": 40,
+ "start": 38,
+ "type": "CLIENT_INITIALIZATION"
+ },
+ {
+ "end": 49,
+ "start": 41,
+ "type": "REQUEST_INITIALIZATION"
+ },
+ {
+ "end": 52,
+ "start": 50,
+ "type": "REQUEST_EXECUTION"
+ },
+ {
+ "end": 56,
+ "start": 53,
+ "type": "RESPONSE_HANDLING"
+ }
+ ],
+ "title": "aiplatform_v1beta1_generated_prediction_service_count_tokens_async.py"
+ },
+ {
+ "canonical": true,
+ "clientMethod": {
+ "client": {
+ "fullName": "google.cloud.aiplatform_v1beta1.PredictionServiceClient",
+ "shortName": "PredictionServiceClient"
+ },
+ "fullName": "google.cloud.aiplatform_v1beta1.PredictionServiceClient.count_tokens",
+ "method": {
+ "fullName": "google.cloud.aiplatform.v1beta1.PredictionService.CountTokens",
+ "service": {
+ "fullName": "google.cloud.aiplatform.v1beta1.PredictionService",
+ "shortName": "PredictionService"
+ },
+ "shortName": "CountTokens"
+ },
+ "parameters": [
+ {
+ "name": "request",
+ "type": "google.cloud.aiplatform_v1beta1.types.CountTokensRequest"
+ },
+ {
+ "name": "endpoint",
+ "type": "str"
+ },
+ {
+ "name": "instances",
+ "type": "MutableSequence[google.protobuf.struct_pb2.Value]"
+ },
+ {
+ "name": "retry",
+ "type": "google.api_core.retry.Retry"
+ },
+ {
+ "name": "timeout",
+ "type": "float"
+ },
+ {
+ "name": "metadata",
+ "type": "Sequence[Tuple[str, str]"
+ }
+ ],
+ "resultType": "google.cloud.aiplatform_v1beta1.types.CountTokensResponse",
+ "shortName": "count_tokens"
+ },
+ "description": "Sample for CountTokens",
+ "file": "aiplatform_v1beta1_generated_prediction_service_count_tokens_sync.py",
+ "language": "PYTHON",
+ "origin": "API_DEFINITION",
+ "regionTag": "aiplatform_v1beta1_generated_PredictionService_CountTokens_sync",
+ "segments": [
+ {
+ "end": 55,
+ "start": 27,
+ "type": "FULL"
+ },
+ {
+ "end": 55,
+ "start": 27,
+ "type": "SHORT"
+ },
+ {
+ "end": 40,
+ "start": 38,
+ "type": "CLIENT_INITIALIZATION"
+ },
+ {
+ "end": 49,
+ "start": 41,
+ "type": "REQUEST_INITIALIZATION"
+ },
+ {
+ "end": 52,
+ "start": 50,
+ "type": "REQUEST_EXECUTION"
+ },
+ {
+ "end": 56,
+ "start": 53,
+ "type": "RESPONSE_HANDLING"
+ }
+ ],
+ "title": "aiplatform_v1beta1_generated_prediction_service_count_tokens_sync.py"
+ },
{
"canonical": true,
"clientMethod": {
diff --git a/samples/snippets/noxfile.py b/samples/snippets/noxfile.py
index 1224cbe212..7c8a63994c 100644
--- a/samples/snippets/noxfile.py
+++ b/samples/snippets/noxfile.py
@@ -160,7 +160,6 @@ def blacken(session: nox.sessions.Session) -> None:
# format = isort + black
#
-
@nox.session
def format(session: nox.sessions.Session) -> None:
"""
@@ -188,9 +187,7 @@ def _session_tests(
session: nox.sessions.Session, post_install: Callable = None
) -> None:
# check for presence of tests
- test_list = glob.glob("**/*_test.py", recursive=True) + glob.glob(
- "**/test_*.py", recursive=True
- )
+ test_list = glob.glob("**/*_test.py", recursive=True) + glob.glob("**/test_*.py", recursive=True)
test_list.extend(glob.glob("**/tests", recursive=True))
if len(test_list) == 0:
@@ -212,7 +209,9 @@ def _session_tests(
if os.path.exists("requirements-test.txt"):
if os.path.exists("constraints-test.txt"):
- session.install("-r", "requirements-test.txt", "-c", "constraints-test.txt")
+ session.install(
+ "-r", "requirements-test.txt", "-c", "constraints-test.txt"
+ )
else:
session.install("-r", "requirements-test.txt")
with open("requirements-test.txt") as rtfile:
@@ -225,9 +224,9 @@ def _session_tests(
post_install(session)
if "pytest-parallel" in packages:
- concurrent_args.extend(["--workers", "auto", "--tests-per-worker", "auto"])
+ concurrent_args.extend(['--workers', 'auto', '--tests-per-worker', 'auto'])
elif "pytest-xdist" in packages:
- concurrent_args.extend(["-n", "auto"])
+ concurrent_args.extend(['-n', 'auto'])
session.run(
"pytest",
@@ -257,7 +256,7 @@ def py(session: nox.sessions.Session) -> None:
def _get_repo_root() -> Optional[str]:
- """Returns the root folder of the project."""
+ """ Returns the root folder of the project. """
# Get root of this repository. Assume we don't have directories nested deeper than 10 items.
p = Path(os.getcwd())
for i in range(10):
diff --git a/setup.py b/setup.py
index e5a1dae6ab..0b16e349a0 100644
--- a/setup.py
+++ b/setup.py
@@ -59,14 +59,16 @@
"pyarrow >= 6.0.1",
]
pipelines_extra_require = [
- "pyyaml>=5.3,<7",
+ "pyyaml==5.3.1",
]
datasets_extra_require = [
- "pyarrow >= 3.0.0, < 8.0dev",
+ "pyarrow >= 3.0.0, < 8.0dev; python_version<'3.11'",
+ "pyarrow >= 10.0.1; python_version>='3.11'",
]
vizier_extra_require = [
- "google-vizier==0.0.4",
+ "google-vizier==0.0.4; python_version<'3.11'",
+ "google-vizier>=0.1.6; python_version>='3.11'",
]
prediction_extra_require = [
diff --git a/tests/system/aiplatform/test_e2e_tabular.py b/tests/system/aiplatform/test_e2e_tabular.py
index f6700fa099..20b49999d3 100644
--- a/tests/system/aiplatform/test_e2e_tabular.py
+++ b/tests/system/aiplatform/test_e2e_tabular.py
@@ -106,6 +106,7 @@ def test_end_to_end_tabular(self, shared_state):
enable_web_access=True,
sync=False,
create_request_timeout=None,
+ disable_retries=True,
)
automl_model = automl_job.run(
diff --git a/tests/system/aiplatform/test_language_models.py b/tests/system/aiplatform/test_language_models.py
index 1a281d671c..607d523d24 100644
--- a/tests/system/aiplatform/test_language_models.py
+++ b/tests/system/aiplatform/test_language_models.py
@@ -22,6 +22,9 @@
job_state as gca_job_state,
)
from tests.system.aiplatform import e2e_base
+from google.cloud.aiplatform.utils import gcs_utils
+from vertexai import language_models
+from vertexai.preview import language_models as preview_language_models
from vertexai.preview.language_models import (
ChatModel,
InputOutputTextPair,
@@ -29,6 +32,8 @@
TextEmbeddingModel,
)
+STAGING_DIR_URI = "gs://ucaip-samples-us-central1/tmp/staging"
+
class TestLanguageModels(e2e_base.TestEndToEnd):
"""System tests for language models."""
@@ -46,8 +51,23 @@ def test_text_generation(self):
temperature=0,
top_p=1,
top_k=5,
+ stop_sequences=["# %%"],
).text
+ def test_text_generation_streaming(self):
+ aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+ model = TextGenerationModel.from_pretrained("google/text-bison@001")
+
+ for response in model.predict_streaming(
+ "What is the best recipe for banana bread? Recipe:",
+ max_output_tokens=128,
+ temperature=0,
+ top_p=1,
+ top_k=5,
+ ):
+ assert response.text
+
def test_chat_on_chat_model(self):
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
@@ -65,6 +85,7 @@ def test_chat_on_chat_model(self):
),
],
temperature=0.0,
+ stop_sequences=["# %%"],
)
message1 = "Are my favorite movies based on a book series?"
@@ -86,21 +107,65 @@ def test_chat_on_chat_model(self):
assert chat.message_history[2].content == message2
assert chat.message_history[3].author == chat.MODEL_AUTHOR
+ def test_chat_model_send_message_streaming(self):
+ aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+ chat_model = ChatModel.from_pretrained("google/chat-bison@001")
+ chat = chat_model.start_chat(
+ context="My name is Ned. You are my personal assistant. My favorite movies are Lord of the Rings and Hobbit.",
+ examples=[
+ InputOutputTextPair(
+ input_text="Who do you work for?",
+ output_text="I work for Ned.",
+ ),
+ InputOutputTextPair(
+ input_text="What do I like?",
+ output_text="Ned likes watching movies.",
+ ),
+ ],
+ temperature=0.0,
+ )
+
+ message1 = "Are my favorite movies based on a book series?"
+ for response in chat.send_message_streaming(message1):
+ assert response.text
+ assert len(chat.message_history) == 2
+ assert chat.message_history[0].author == chat.USER_AUTHOR
+ assert chat.message_history[0].content == message1
+ assert chat.message_history[1].author == chat.MODEL_AUTHOR
+
+ message2 = "When were these books published?"
+ for response2 in chat.send_message_streaming(
+ message2,
+ temperature=0.1,
+ ):
+ assert response2.text
+ assert len(chat.message_history) == 4
+ assert chat.message_history[2].author == chat.USER_AUTHOR
+ assert chat.message_history[2].content == message2
+ assert chat.message_history[3].author == chat.MODEL_AUTHOR
+
def test_text_embedding(self):
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
model = TextEmbeddingModel.from_pretrained("google/textembedding-gecko@001")
- embeddings = model.get_embeddings(["What is life?"])
- assert embeddings
- for embedding in embeddings:
- vector = embedding.values
- assert len(vector) == 768
+ # One short text, one llong text (to check truncation)
+ texts = ["What is life?", "What is life?" * 1000]
+ embeddings = model.get_embeddings(texts)
+ assert len(embeddings) == 2
+ assert len(embeddings[0].values) == 768
+ assert embeddings[0].statistics.token_count > 0
+ assert not embeddings[0].statistics.truncated
+
+ assert len(embeddings[1].values) == 768
+ assert embeddings[1].statistics.token_count > 1000
+ assert embeddings[1].statistics.truncated
def test_tuning(self, shared_state):
"""Test tuning, listing and loading models."""
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
- model = TextGenerationModel.from_pretrained("google/text-bison@001")
+ model = language_models.TextGenerationModel.from_pretrained("text-bison@001")
import pandas
@@ -119,12 +184,24 @@ def test_tuning(self, shared_state):
]
)
- model.tune_model(
+ dataset_uri = (
+ STAGING_DIR_URI + "/veretx_llm_tuning_training_data.text-bison.dummy.jsonl"
+ )
+ gcs_utils._upload_pandas_df_to_gcs(
+ df=training_data, upload_gcs_path=dataset_uri
+ )
+
+ tuning_job = model.tune_model(
training_data=training_data,
train_steps=1,
tuning_job_location="europe-west4",
tuned_model_location="us-central1",
- learning_rate=2.0,
+ learning_rate_multiplier=2.0,
+ tuning_evaluation_spec=preview_language_models.TuningEvaluationSpec(
+ evaluation_data=dataset_uri,
+ evaluation_interval=37,
+ enable_early_stopping=True,
+ ),
)
# According to the Pipelines design, external resources created by a pipeline
# must not be modified or deleted. Otherwise caching will break next pipeline runs.
@@ -136,6 +213,18 @@ def test_tuning(self, shared_state):
)
# Deleting the Endpoint is a little less bad since the LLM SDK will recreate it, but it's not advised for the same reason.
+ # Testing the new model returned by the `tuning_job.get_tuned_model` method
+ tuned_model1 = tuning_job.get_tuned_model()
+ response1 = tuned_model1.predict(
+ "What is the best recipe for banana bread? Recipe:",
+ max_output_tokens=128,
+ temperature=0,
+ top_p=1,
+ top_k=5,
+ )
+ assert response1.text
+
+ # Testing the model updated in-place (Deprecated. Preview only)
response = model.predict(
"What is the best recipe for banana bread? Recipe:",
max_output_tokens=128,
@@ -199,3 +288,27 @@ def test_batch_prediction_for_textembedding(self):
job.delete()
assert gapic_job.state == gca_job_state.JobState.JOB_STATE_SUCCEEDED
+
+ def test_code_generation_streaming(self):
+ aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+ model = language_models.CodeGenerationModel.from_pretrained("code-bison@001")
+
+ for response in model.predict_streaming(
+ prefix="def reverse_string(s):",
+ # code-bison does not support suffix
+ # suffix=" return s",
+ max_output_tokens=128,
+ temperature=0,
+ ):
+ assert response.text
+
+ def test_code_chat_model_send_message_streaming(self):
+ aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+ chat_model = language_models.ChatModel.from_pretrained("codeodechat-bison@001")
+ chat = chat_model.start_chat()
+
+ message1 = "Please help write a function to calculate the max of two numbers"
+ for response in chat.send_message_streaming(message1):
+ assert response.text
diff --git a/tests/system/aiplatform/test_model_evaluation.py b/tests/system/aiplatform/test_model_evaluation.py
index 733ea0dc1d..565edaefa6 100644
--- a/tests/system/aiplatform/test_model_evaluation.py
+++ b/tests/system/aiplatform/test_model_evaluation.py
@@ -83,6 +83,7 @@ def staging_bucket(self, storage_client):
yield bucket
def test_model_evaluate_custom_tabular_model(self, staging_bucket, shared_state):
+ aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
custom_model = aiplatform.Model(
model_name=_TEST_PERMANENT_CUSTOM_MODEL_CLASSIFICATION_RESOURCE_NAME
diff --git a/tests/system/aiplatform/test_pipeline_job_schedule.py b/tests/system/aiplatform/test_pipeline_job_schedule.py
index 44c98a1d69..7eab611a21 100644
--- a/tests/system/aiplatform/test_pipeline_job_schedule.py
+++ b/tests/system/aiplatform/test_pipeline_job_schedule.py
@@ -16,16 +16,11 @@
#
from google.cloud import aiplatform
-from google.cloud.aiplatform.compat.types import (
- schedule_v1beta1 as gca_schedule,
-)
-from google.cloud.aiplatform.preview.pipelinejobschedule import (
- pipeline_job_schedules,
-)
+from google.cloud.aiplatform import pipeline_job_schedules
+from google.cloud.aiplatform.compat.types import schedule as gca_schedule
from tests.system.aiplatform import e2e_base
-
-from kfp import components
-from kfp.v2 import compiler
+from kfp import compiler
+from kfp import dsl
import pytest
from google.protobuf.json_format import MessageToDict
@@ -34,11 +29,12 @@
@pytest.mark.usefixtures(
"tear_down_resources", "prepare_staging_bucket", "delete_staging_bucket"
)
-class TestPreviewPipelineJobSchedule(e2e_base.TestEndToEnd):
+class TestPipelineJobSchedule(e2e_base.TestEndToEnd):
_temp_prefix = "tmpvrtxsdk-e2e-pjs"
def test_create_get_pause_resume_update_list(self, shared_state):
# Components:
+ @dsl.component
def train(
number_of_epochs: int,
learning_rate: float,
@@ -46,13 +42,12 @@ def train(
print(f"number_of_epochs={number_of_epochs}")
print(f"learning_rate={learning_rate}")
- train_op = components.create_component_from_func(train)
-
# Pipeline:
+ @dsl.pipeline(name="system-test-training-pipeline")
def training_pipeline(number_of_epochs: int = 2):
- train_op(
+ train(
number_of_epochs=number_of_epochs,
- learning_rate="0.1",
+ learning_rate=0.1,
)
# Creating the pipeline job schedule.
@@ -61,11 +56,10 @@ def training_pipeline(number_of_epochs: int = 2):
location=e2e_base._LOCATION,
)
- ir_file = "pipeline.json"
+ ir_file = "pipeline.yaml"
compiler.Compiler().compile(
pipeline_func=training_pipeline,
package_path=ir_file,
- pipeline_name="system-test-training-pipeline",
)
job = aiplatform.PipelineJob(
template_path=ir_file,
@@ -77,9 +71,9 @@ def training_pipeline(number_of_epochs: int = 2):
)
max_run_count = 2
- cron_expression = "*/5 * * * *"
+ cron = "*/5 * * * *"
pipeline_job_schedule.create(
- cron_expression=cron_expression,
+ cron=cron,
max_run_count=max_run_count,
max_concurrent_run_count=2,
)
@@ -90,13 +84,13 @@ def training_pipeline(number_of_epochs: int = 2):
pipeline_job_schedule.pause()
assert pipeline_job_schedule.state == gca_schedule.Schedule.State.PAUSED
- # Before updating, confirm the cron_expression is correctly set from the create step.
- assert pipeline_job_schedule.cron_expression == cron_expression
+ # Before updating, confirm cron is correctly set from the create step.
+ assert pipeline_job_schedule.cron == cron
# Updating the pipeline job schedule.
- new_cron_expression = "* * * * *"
- pipeline_job_schedule.update(cron_expression=new_cron_expression)
- assert pipeline_job_schedule.cron_expression == new_cron_expression
+ new_cron = "* * * * *"
+ pipeline_job_schedule.update(cron=new_cron)
+ assert pipeline_job_schedule.cron == new_cron
# Resuming the pipeline job schedule.
pipeline_job_schedule.resume(catch_up=True)
@@ -105,12 +99,12 @@ def training_pipeline(number_of_epochs: int = 2):
pipeline_job_schedule.wait()
# Confirming that correct number of runs were scheduled and completed by this pipeline job schedule.
- list_jobs_with_read_mask = pipeline_job_schedule.list_jobs(
- enable_simple_view=True
- )
+ list_jobs_with_read_mask = pipeline_job_schedule.list_jobs()
assert len(list_jobs_with_read_mask) == max_run_count
- list_jobs_without_read_mask = pipeline_job_schedule.list_jobs()
+ list_jobs_without_read_mask = pipeline_job_schedule.list_jobs(
+ enable_simple_view=False
+ )
# enable_simple_view=True should apply the `read_mask` filter to limit PipelineJob fields returned
assert "serviceAccount" in MessageToDict(
diff --git a/tests/system/aiplatform/test_vision_models.py b/tests/system/aiplatform/test_vision_models.py
index ddf7cf7168..f30628d036 100644
--- a/tests/system/aiplatform/test_vision_models.py
+++ b/tests/system/aiplatform/test_vision_models.py
@@ -22,6 +22,7 @@
from google.cloud import aiplatform
from tests.system.aiplatform import e2e_base
+from vertexai import vision_models as ga_vision_models
from vertexai.preview import vision_models
from PIL import Image as PIL_Image
@@ -45,7 +46,7 @@ class VisionModelTestSuite(e2e_base.TestEndToEnd):
def test_image_captioning_model_get_captions(self):
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
- model = vision_models.ImageCaptioningModel.from_pretrained("imagetext")
+ model = ga_vision_models.ImageCaptioningModel.from_pretrained("imagetext")
image = _create_blank_image()
captions = model.get_captions(
image=image,
@@ -58,7 +59,7 @@ def test_image_captioning_model_get_captions(self):
def test_image_q_and_a_model_ask_question(self):
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
- model = vision_models.ImageQnAModel.from_pretrained("imagetext")
+ model = ga_vision_models.ImageQnAModel.from_pretrained("imagetext")
image = _create_blank_image()
answers = model.ask_question(
image=image,
@@ -71,7 +72,7 @@ def test_image_q_and_a_model_ask_question(self):
def test_multi_modal_embedding_model(self):
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
- model = vision_models.MultiModalEmbeddingModel.from_pretrained(
+ model = ga_vision_models.MultiModalEmbeddingModel.from_pretrained(
"multimodalembedding@001"
)
image = _create_blank_image()
@@ -83,3 +84,82 @@ def test_multi_modal_embedding_model(self):
# The service is expected to return the embeddings of size 1408
assert len(embeddings.image_embedding) == 1408
assert len(embeddings.text_embedding) == 1408
+
+ def test_image_generation_model_generate_images(self):
+ """Tests the image generation model generating images."""
+ model = vision_models.ImageGenerationModel.from_pretrained(
+ "imagegeneration@001"
+ )
+
+ # TODO(b/295946075): The service stopped supporting image sizes.
+ # width = 1024
+ # height = 768
+ number_of_images = 4
+ seed = 1
+ guidance_scale = 15
+
+ prompt1 = "Astronaut riding a horse"
+ negative_prompt1 = "bad quality"
+ image_response = model.generate_images(
+ prompt=prompt1,
+ # Optional:
+ negative_prompt=negative_prompt1,
+ number_of_images=number_of_images,
+ # TODO(b/295946075): The service stopped supporting image sizes.
+ # width=width,
+ # height=height,
+ seed=seed,
+ guidance_scale=guidance_scale,
+ )
+
+ assert len(image_response.images) == number_of_images
+ for idx, image in enumerate(image_response):
+ # TODO(b/295946075): The service stopped supporting image sizes.
+ # assert image._pil_image.size == (width, height)
+ assert image.generation_parameters
+ assert image.generation_parameters["prompt"] == prompt1
+ assert image.generation_parameters["negative_prompt"] == negative_prompt1
+ # TODO(b/295946075): The service stopped supporting image sizes.
+ # assert image.generation_parameters["width"] == width
+ # assert image.generation_parameters["height"] == height
+ assert image.generation_parameters["seed"] == seed
+ assert image.generation_parameters["guidance_scale"] == guidance_scale
+ assert image.generation_parameters["index_of_image_in_batch"] == idx
+
+ # Test saving and loading images
+ with tempfile.TemporaryDirectory() as temp_dir:
+ image_path = os.path.join(temp_dir, "image.png")
+ image_response[0].save(location=image_path)
+ image1 = vision_models.GeneratedImage.load_from_file(image_path)
+ # assert image1._pil_image.size == (width, height)
+ assert image1.generation_parameters
+ assert image1.generation_parameters["prompt"] == prompt1
+
+ # Preparing mask
+ mask_path = os.path.join(temp_dir, "mask.png")
+ mask_pil_image = PIL_Image.new(mode="RGB", size=image1._pil_image.size)
+ mask_pil_image.save(mask_path, format="PNG")
+ mask_image = vision_models.Image.load_from_file(mask_path)
+
+ # Test generating image from base image
+ prompt2 = "Ancient book style"
+ image_response2 = model.edit_image(
+ prompt=prompt2,
+ # Optional:
+ number_of_images=number_of_images,
+ seed=seed,
+ guidance_scale=guidance_scale,
+ base_image=image1,
+ mask=mask_image,
+ )
+ assert len(image_response2.images) == number_of_images
+ for idx, image in enumerate(image_response2):
+ # TODO(b/295946075): The service stopped supporting image sizes.
+ # assert image._pil_image.size == (width, height)
+ assert image.generation_parameters
+ assert image.generation_parameters["prompt"] == prompt2
+ assert image.generation_parameters["seed"] == seed
+ assert image.generation_parameters["guidance_scale"] == guidance_scale
+ assert image.generation_parameters["index_of_image_in_batch"] == idx
+ assert "base_image_hash" in image.generation_parameters
+ assert "mask_hash" in image.generation_parameters
diff --git a/tests/unit/aiplatform/constants.py b/tests/unit/aiplatform/constants.py
index 4ca82c7746..80faded765 100644
--- a/tests/unit/aiplatform/constants.py
+++ b/tests/unit/aiplatform/constants.py
@@ -125,6 +125,7 @@ class TrainingJobConstants:
)
_TEST_TIMEOUT = 8000
_TEST_RESTART_JOB_ON_WORKER_RESTART = True
+ _TEST_DISABLE_RETRIES = True
_TEST_BASE_CUSTOM_JOB_PROTO = custom_job.CustomJob(
display_name=_TEST_DISPLAY_NAME,
@@ -136,6 +137,7 @@ class TrainingJobConstants:
scheduling=custom_job.Scheduling(
timeout=duration_pb2.Duration(seconds=_TEST_TIMEOUT),
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
+ disable_retries=_TEST_DISABLE_RETRIES,
),
service_account=ProjectConstants._TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
diff --git a/tests/unit/aiplatform/test_custom_job.py b/tests/unit/aiplatform/test_custom_job.py
index e91f90cefd..ea43c42a4f 100644
--- a/tests/unit/aiplatform/test_custom_job.py
+++ b/tests/unit/aiplatform/test_custom_job.py
@@ -126,6 +126,7 @@
_TEST_RESTART_JOB_ON_WORKER_RESTART = (
test_constants.TrainingJobConstants._TEST_RESTART_JOB_ON_WORKER_RESTART
)
+_TEST_DISABLE_RETRIES = test_constants.TrainingJobConstants._TEST_DISABLE_RETRIES
_TEST_LABELS = test_constants.ProjectConstants._TEST_LABELS
@@ -421,6 +422,7 @@ def test_create_custom_job(self, create_custom_job_mock, get_custom_job_mock, sy
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
sync=sync,
create_request_timeout=None,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
job.wait_for_resource_creation()
@@ -465,6 +467,7 @@ def test_submit_custom_job(self, create_custom_job_mock, get_custom_job_mock):
timeout=_TEST_TIMEOUT,
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
create_request_timeout=None,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
job.wait_for_resource_creation()
@@ -516,6 +519,7 @@ def test_submit_custom_job_with_experiments(
create_request_timeout=None,
experiment=_TEST_EXPERIMENT,
experiment_run=_TEST_EXPERIMENT_RUN,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
job.wait_for_resource_creation()
@@ -569,6 +573,7 @@ def test_create_custom_job_with_timeout(
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
sync=sync,
create_request_timeout=180.0,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
job.wait_for_resource_creation()
@@ -610,6 +615,7 @@ def test_create_custom_job_with_timeout_not_explicitly_set(
timeout=_TEST_TIMEOUT,
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
sync=sync,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
job.wait_for_resource_creation()
@@ -656,6 +662,7 @@ def test_run_custom_job_with_fail_raises(
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
sync=sync,
create_request_timeout=None,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
job.wait()
@@ -696,6 +703,7 @@ def test_run_custom_job_with_fail_at_creation(self):
timeout=_TEST_TIMEOUT,
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
sync=False,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
with pytest.raises(RuntimeError) as e:
@@ -1012,6 +1020,7 @@ def test_create_custom_job_with_enable_web_access(
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
sync=sync,
create_request_timeout=None,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
job.wait_for_resource_creation()
@@ -1083,6 +1092,7 @@ def test_create_custom_job_with_tensorboard(
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
sync=sync,
create_request_timeout=None,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
job.wait()
@@ -1149,6 +1159,7 @@ def test_check_custom_job_availability(self):
network=_TEST_NETWORK,
timeout=_TEST_TIMEOUT,
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
job.wait_for_resource_creation()
diff --git a/tests/unit/aiplatform/test_custom_job_persistent_resource.py b/tests/unit/aiplatform/test_custom_job_persistent_resource.py
index 3405feb9da..3b23c05fcd 100644
--- a/tests/unit/aiplatform/test_custom_job_persistent_resource.py
+++ b/tests/unit/aiplatform/test_custom_job_persistent_resource.py
@@ -71,6 +71,7 @@
_TEST_RESTART_JOB_ON_WORKER_RESTART = (
test_constants.TrainingJobConstants._TEST_RESTART_JOB_ON_WORKER_RESTART
)
+_TEST_DISABLE_RETRIES = test_constants.TrainingJobConstants._TEST_DISABLE_RETRIES
_TEST_LABELS = test_constants.ProjectConstants._TEST_LABELS
@@ -87,6 +88,7 @@
scheduling=custom_job_v1beta1.Scheduling(
timeout=duration_pb2.Duration(seconds=_TEST_TIMEOUT),
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
+ disable_retries=_TEST_DISABLE_RETRIES,
),
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
@@ -175,6 +177,7 @@ def test_create_custom_job_with_persistent_resource(
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
sync=sync,
create_request_timeout=None,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
job.wait_for_resource_creation()
@@ -222,6 +225,7 @@ def test_submit_custom_job_with_persistent_resource(
timeout=_TEST_TIMEOUT,
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
create_request_timeout=None,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
job.wait_for_resource_creation()
diff --git a/tests/unit/aiplatform/test_explain_saved_model_metadata_builder_tf2_test.py b/tests/unit/aiplatform/test_explain_saved_model_metadata_builder_tf2_test.py
index a6685c5285..3f6be1dc3d 100644
--- a/tests/unit/aiplatform/test_explain_saved_model_metadata_builder_tf2_test.py
+++ b/tests/unit/aiplatform/test_explain_saved_model_metadata_builder_tf2_test.py
@@ -116,6 +116,7 @@ def call(self, inputs):
}
assert expected_md == generated_md
+ @pytest.mark.skip(reason="Failing for Python 3.11, tracked in b/293506827.")
def test_non_keras_model(self):
class CustomModuleWithOutputName(tf.Module):
def __init__(self):
diff --git a/tests/unit/aiplatform/test_hyperparameter_tuning_job.py b/tests/unit/aiplatform/test_hyperparameter_tuning_job.py
index 911115bb59..c625b7b442 100644
--- a/tests/unit/aiplatform/test_hyperparameter_tuning_job.py
+++ b/tests/unit/aiplatform/test_hyperparameter_tuning_job.py
@@ -66,6 +66,7 @@
_TEST_RESTART_JOB_ON_WORKER_RESTART = (
test_constants.TrainingJobConstants._TEST_RESTART_JOB_ON_WORKER_RESTART
)
+_TEST_DISABLE_RETRIES = test_constants.TrainingJobConstants._TEST_DISABLE_RETRIES
_TEST_METRIC_SPEC_KEY = "test-metric"
_TEST_METRIC_SPEC_VALUE = "maximize"
@@ -448,6 +449,7 @@ def test_create_hyperparameter_tuning_job(
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
sync=sync,
create_request_timeout=None,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
job.wait()
@@ -519,6 +521,7 @@ def test_create_hyperparameter_tuning_job_with_timeout(
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
sync=sync,
create_request_timeout=180.0,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
job.wait()
@@ -586,6 +589,7 @@ def test_run_hyperparameter_tuning_job_with_fail_raises(
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
sync=sync,
create_request_timeout=None,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
job.wait()
@@ -647,6 +651,7 @@ def test_run_hyperparameter_tuning_job_with_fail_at_creation(self):
timeout=_TEST_TIMEOUT,
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
sync=False,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
with pytest.raises(RuntimeError) as e:
@@ -783,6 +788,7 @@ def test_create_hyperparameter_tuning_job_with_tensorboard(
tensorboard=test_constants.TensorboardConstants._TEST_TENSORBOARD_NAME,
sync=sync,
create_request_timeout=None,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
job.wait()
@@ -860,6 +866,7 @@ def test_create_hyperparameter_tuning_job_with_enable_web_access(
enable_web_access=test_constants.TrainingJobConstants._TEST_ENABLE_WEB_ACCESS,
sync=sync,
create_request_timeout=None,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
job.wait()
diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py
index 1c4781c9bc..c3836f88cb 100644
--- a/tests/unit/aiplatform/test_language_models.py
+++ b/tests/unit/aiplatform/test_language_models.py
@@ -28,6 +28,7 @@
from google.cloud import storage
from google.cloud import aiplatform
+from google.cloud.aiplatform import _streaming_prediction
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.utils import gcs_utils
@@ -41,6 +42,7 @@
)
from google.cloud.aiplatform.compat.services import prediction_service_client
from google.cloud.aiplatform.compat.types import (
+ artifact as gca_artifact,
prediction_service as gca_prediction_service,
context as gca_context,
endpoint as gca_endpoint,
@@ -58,6 +60,9 @@
)
from vertexai import language_models
from vertexai.language_models import _language_models
+from vertexai.language_models import (
+ _evaluatable_language_models,
+)
from google.cloud.aiplatform_v1 import Execution as GapicExecution
from google.cloud.aiplatform.compat.types import (
encryption_spec as gca_encryption_spec,
@@ -164,6 +169,34 @@
1. Preheat oven to 350 degrees F (175 degrees C).""",
}
+_TEST_TEXT_GENERATION_PREDICTION_STREAMING = [
+ {
+ "content": "1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17.",
+ },
+ {
+ "content": " 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31.",
+ "safetyAttributes": {"blocked": False, "categories": None, "scores": None},
+ },
+ {
+ "content": " 32. 33. 34. 35. 36. 37. 38. 39. 40. 41. 42. 43. 44. 45.",
+ "citationMetadata": {
+ "citations": [
+ {
+ "title": "THEATRUM ARITHMETICO-GEOMETRICUM",
+ "publicationDate": "1727",
+ "endIndex": 181,
+ "startIndex": 12,
+ }
+ ]
+ },
+ "safetyAttributes": {
+ "blocked": True,
+ "categories": ["Finance"],
+ "scores": [0.1],
+ },
+ },
+]
+
_TEST_CHAT_GENERATION_PREDICTION1 = {
"safetyAttributes": [
{
@@ -195,6 +228,33 @@
],
}
+_TEST_CHAT_PREDICTION_STREAMING = [
+ {
+ "candidates": [
+ {
+ "author": "1",
+ "content": "1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15.",
+ }
+ ],
+ "safetyAttributes": [{"blocked": False, "categories": None, "scores": None}],
+ },
+ {
+ "candidates": [
+ {
+ "author": "1",
+ "content": " 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27.",
+ }
+ ],
+ "safetyAttributes": [
+ {
+ "blocked": True,
+ "categories": ["Finance"],
+ "scores": [0.1],
+ }
+ ],
+ },
+]
+
_TEST_CODE_GENERATION_PREDICTION = {
"safetyAttributes": {
"categories": [],
@@ -238,6 +298,7 @@ def reverse_string_2(s):""",
_TEST_TEXT_EMBEDDING_PREDICTION = {
"embeddings": {
"values": list([1.0] * _TEXT_EMBEDDING_VECTOR_LENGTH),
+ "statistics": {"truncated": False, "token_count": 4.0},
}
}
@@ -279,29 +340,74 @@ def reverse_string_2(s):""",
"isOptional": True,
"parameterType": "STRING",
},
+ "default_context": {
+ "defaultValue": "",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "enable_early_stopping": {
+ "defaultValue": True,
+ "isOptional": True,
+ "parameterType": "BOOLEAN",
+ },
"encryption_spec_key_name": {
"defaultValue": "",
"isOptional": True,
"parameterType": "STRING",
},
+ "evaluation_data_uri": {
+ "defaultValue": "",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "evaluation_interval": {
+ "defaultValue": 20,
+ "isOptional": True,
+ "parameterType": "NUMBER_INTEGER",
+ },
+ "evaluation_output_root_dir": {
+ "defaultValue": "",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
"large_model_reference": {
- "defaultValue": "text-bison-001",
+ "defaultValue": "text-bison@001",
"isOptional": True,
"parameterType": "STRING",
},
"learning_rate": {
- "defaultValue": 3,
+ "defaultValue": -1,
+ "isOptional": True,
+ "parameterType": "NUMBER_DOUBLE",
+ },
+ "learning_rate_multiplier": {
+ "defaultValue": 1,
"isOptional": True,
"parameterType": "NUMBER_DOUBLE",
},
"location": {"parameterType": "STRING"},
"model_display_name": {"parameterType": "STRING"},
"project": {"parameterType": "STRING"},
+ "tensorboard_resource_id": {
+ "defaultValue": "",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "tpu_training_skip_cmek": {
+ "defaultValue": False,
+ "isOptional": True,
+ "parameterType": "BOOLEAN",
+ },
"train_steps": {
- "defaultValue": 1000,
+ "defaultValue": 300,
"isOptional": True,
"parameterType": "NUMBER_INTEGER",
},
+ "tuning_method": {
+ "defaultValue": "tune_v2",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
}
},
},
@@ -321,6 +427,251 @@ def reverse_string_2(s):""",
}
)
+_TEST_TEXT_GENERATION_METRICS = {
+ "bleu": 3.9311041439597427,
+ "rougeLSum": 19.014677479620463,
+}
+
+
+_TEST_TEXT_CLASSIFICATION_METRICS = {"auPrc": 0.9, "auRoc": 0.8, "logLoss": 0.5}
+
+_TEST_EVAL_DATA = [
+ {
+ "prompt": "Basketball teams in the Midwest.",
+ "ground_truth": "There are several basketball teams located in the Midwest region of the United States. Here are some of them:",
+ },
+ {
+ "prompt": "How to bake gluten-free bread?",
+ "ground_truth": "Baking gluten-free bread can be a bit challenging because gluten is the protein that gives bread its structure and elasticity.",
+ },
+]
+
+_TEST_EVAL_DATA_DF = pd.DataFrame(_TEST_EVAL_DATA)
+
+_TEST_ARTIFACT_ID = "123456"
+_TEST_ARTIFACT_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/default/artifacts/{_TEST_ARTIFACT_ID}"
+
+_TEST_EVAL_PIPELINE_SPEC = {
+ "components": {},
+ "pipelineInfo": {"name": "evaluation-llm-text-generation-pipeline"},
+ "root": {
+ "dag": {"tasks": {}},
+ "inputDefinitions": {
+ "parameters": {
+ "batch_predict_accelerator_count": {
+ "defaultValue": 0.0,
+ "isOptional": True,
+ "parameterType": "NUMBER_INTEGER",
+ },
+ "batch_predict_accelerator_type": {
+ "defaultValue": "",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "batch_predict_gcs_source_uris": {
+ "defaultValue": [],
+ "isOptional": True,
+ "parameterType": "LIST",
+ },
+ "batch_predict_gcs_destination_output_uri": {
+ "defaultValue": "",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "batch_predict_predictions_format": {
+ "defaultValue": "jsonl",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "enable_web_access": {
+ "defaultValue": True,
+ "isOptional": True,
+ "parameterType": "BOOLEAN",
+ },
+ "encryption_spec_key_name": {
+ "defaultValue": "",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "evaluation_display_name": {
+ "defaultValue": "evaluation-text-generation",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "location": {
+ "defaultValue": "us-central1",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "machine_type": {
+ "defaultValue": "e2-highmem-16",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "model_name": {"parameterType": "STRING"},
+ "evaluation_task": {"parameterType": "STRING"},
+ "network": {
+ "defaultValue": "",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "nlp_task": {
+ "defaultValue": "text-generation",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "predictions_format": {
+ "defaultValue": "jsonl",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "predictions_gcs_source": {
+ "defaultValue": [],
+ "isOptional": True,
+ "parameterType": "LIST",
+ },
+ "project": {"parameterType": "STRING"},
+ "root_dir": {"parameterType": "STRING"},
+ "service_account": {
+ "defaultValue": "",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ }
+ },
+ },
+ "schemaVersion": "2.1.0",
+ "sdkVersion": "kfp-2.0.0-beta.14",
+}
+
+
+_TEST_EVAL_PIPELINE_SPEC_JSON = json.dumps(
+ _TEST_EVAL_PIPELINE_SPEC,
+)
+
+_TEST_EVAL_PIPELINE_JOB = json.dumps(
+ {
+ "runtimeConfig": {"parameterValues": {}},
+ "pipelineSpec": json.loads(_TEST_EVAL_PIPELINE_SPEC_JSON),
+ }
+)
+
+# Eval classification spec
+
+_TEST_EVAL_CLASSIFICATION_PIPELINE_SPEC = {
+ "components": {},
+ "pipelineInfo": {"name": "evaluation-llm-text-generation-pipeline"},
+ "root": {
+ "dag": {"tasks": {}},
+ "inputDefinitions": {
+ "parameters": {
+ "batch_predict_accelerator_count": {
+ "defaultValue": 0.0,
+ "isOptional": True,
+ "parameterType": "NUMBER_INTEGER",
+ },
+ "batch_predict_accelerator_type": {
+ "defaultValue": "",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "batch_predict_gcs_source_uris": {
+ "defaultValue": [],
+ "isOptional": True,
+ "parameterType": "LIST",
+ },
+ "batch_predict_gcs_destination_output_uri": {
+ "defaultValue": "",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "batch_predict_predictions_format": {
+ "defaultValue": "jsonl",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "enable_web_access": {
+ "defaultValue": True,
+ "isOptional": True,
+ "parameterType": "BOOLEAN",
+ },
+ "encryption_spec_key_name": {
+ "defaultValue": "",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "evaluation_display_name": {
+ "defaultValue": "evaluation-text-generation",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "location": {
+ "defaultValue": "us-central1",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "machine_type": {
+ "defaultValue": "e2-highmem-16",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "model_name": {"parameterType": "STRING"},
+ "evaluation_task": {"parameterType": "STRING"},
+ "network": {
+ "defaultValue": "",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "nlp_task": {
+ "defaultValue": "text-generation",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "predictions_format": {
+ "defaultValue": "jsonl",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "predictions_gcs_source": {
+ "defaultValue": [],
+ "isOptional": True,
+ "parameterType": "LIST",
+ },
+ "evaluation_class_labels": {
+ "defaultValue": [],
+ "isOptional": True,
+ "parameterType": "LIST",
+ },
+ "target_field_name": {
+ "defaultValue": "",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ "project": {"parameterType": "STRING"},
+ "root_dir": {"parameterType": "STRING"},
+ "service_account": {
+ "defaultValue": "",
+ "isOptional": True,
+ "parameterType": "STRING",
+ },
+ }
+ },
+ },
+ "schemaVersion": "2.1.0",
+ "sdkVersion": "kfp-2.0.0-beta.14",
+}
+
+_TEST_EVAL_CLASSIFICATION_PIPELINE_SPEC_JSON = json.dumps(
+ _TEST_EVAL_CLASSIFICATION_PIPELINE_SPEC,
+)
+
+_TEST_EVAL_CLASSIFICATION_PIPELINE_JOB = json.dumps(
+ {
+ "runtimeConfig": {"parameterValues": {}},
+ "pipelineSpec": json.loads(_TEST_EVAL_PIPELINE_SPEC_JSON),
+ }
+)
+
@pytest.fixture
def mock_pipeline_bucket_exists():
@@ -377,6 +728,96 @@ def make_pipeline_job(state):
)
+def make_eval_pipeline_job(state):
+ return gca_pipeline_job.PipelineJob(
+ name=test_constants.PipelineJobConstants._TEST_PIPELINE_JOB_NAME,
+ state=state,
+ create_time=test_constants.PipelineJobConstants._TEST_PIPELINE_CREATE_TIME,
+ service_account=test_constants.ProjectConstants._TEST_SERVICE_ACCOUNT,
+ network=test_constants.TrainingJobConstants._TEST_NETWORK,
+ job_detail=gca_pipeline_job.PipelineJobDetail(
+ pipeline_run_context=gca_context.Context(
+ name=test_constants.PipelineJobConstants._TEST_PIPELINE_JOB_NAME,
+ ),
+ task_details=[
+ gca_pipeline_job.PipelineTaskDetail(
+ task_id=456,
+ task_name=test_constants.PipelineJobConstants._TEST_PIPELINE_JOB_ID,
+ outputs={
+ "evaluation_metrics": gca_pipeline_job.PipelineTaskDetail.ArtifactList(
+ artifacts=[
+ gca_artifact.Artifact(
+ name="test-metric-artifact",
+ metadata=_TEST_TEXT_GENERATION_METRICS,
+ ),
+ ],
+ )
+ },
+ ),
+ gca_pipeline_job.PipelineTaskDetail(
+ task_id=789,
+ task_name=test_constants.PipelineJobConstants._TEST_PIPELINE_JOB_ID,
+ outputs={
+ "evaluation_metrics": gca_pipeline_job.PipelineTaskDetail.ArtifactList(
+ artifacts=[
+ gca_artifact.Artifact(
+ display_name="evaluation_metrics",
+ uri="gs://test-bucket/evaluation_metrics",
+ ),
+ ]
+ )
+ },
+ ),
+ ],
+ ),
+ )
+
+
+def make_eval_classification_pipeline_job(state):
+ return gca_pipeline_job.PipelineJob(
+ name=test_constants.PipelineJobConstants._TEST_PIPELINE_JOB_NAME,
+ state=state,
+ create_time=test_constants.PipelineJobConstants._TEST_PIPELINE_CREATE_TIME,
+ service_account=test_constants.ProjectConstants._TEST_SERVICE_ACCOUNT,
+ network=test_constants.TrainingJobConstants._TEST_NETWORK,
+ job_detail=gca_pipeline_job.PipelineJobDetail(
+ pipeline_run_context=gca_context.Context(
+ name=test_constants.PipelineJobConstants._TEST_PIPELINE_JOB_NAME,
+ ),
+ task_details=[
+ gca_pipeline_job.PipelineTaskDetail(
+ task_id=456,
+ task_name=test_constants.PipelineJobConstants._TEST_PIPELINE_JOB_ID,
+ outputs={
+ "evaluation_metrics": gca_pipeline_job.PipelineTaskDetail.ArtifactList(
+ artifacts=[
+ gca_artifact.Artifact(
+ name="test-metric-artifact",
+ metadata=_TEST_TEXT_CLASSIFICATION_METRICS,
+ ),
+ ],
+ )
+ },
+ ),
+ gca_pipeline_job.PipelineTaskDetail(
+ task_id=789,
+ task_name=test_constants.PipelineJobConstants._TEST_PIPELINE_JOB_ID,
+ outputs={
+ "evaluation_metrics": gca_pipeline_job.PipelineTaskDetail.ArtifactList(
+ artifacts=[
+ gca_artifact.Artifact(
+ display_name="evaluation_metrics",
+ uri="gs://test-bucket/evaluation_metrics",
+ ),
+ ]
+ )
+ },
+ ),
+ ],
+ ),
+ )
+
+
@pytest.fixture
def mock_pipeline_service_create():
with mock.patch.object(
@@ -389,13 +830,35 @@ def mock_pipeline_service_create():
@pytest.fixture
-def mock_pipeline_job_get():
+def mock_pipeline_service_create_eval():
with mock.patch.object(
- pipeline_service_client.PipelineServiceClient, "get_pipeline_job"
- ) as mock_get_pipeline_job:
- mock_get_pipeline_job.side_effect = [
- make_pipeline_job(gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING),
- make_pipeline_job(
+ pipeline_service_client.PipelineServiceClient, "create_pipeline_job"
+ ) as mock_create_pipeline_job:
+ mock_create_pipeline_job.return_value = make_eval_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ )
+ yield mock_create_pipeline_job
+
+
+@pytest.fixture
+def mock_pipeline_service_create_eval_classification():
+ with mock.patch.object(
+ pipeline_service_client.PipelineServiceClient, "create_pipeline_job"
+ ) as mock_create_pipeline_job:
+ mock_create_pipeline_job.return_value = make_eval_classification_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ )
+ yield mock_create_pipeline_job
+
+
+@pytest.fixture
+def mock_pipeline_job_get():
+ with mock.patch.object(
+ pipeline_service_client.PipelineServiceClient, "get_pipeline_job"
+ ) as mock_get_pipeline_job:
+ mock_get_pipeline_job.side_effect = [
+ make_pipeline_job(gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING),
+ make_pipeline_job(
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
),
make_pipeline_job(
@@ -424,6 +887,82 @@ def mock_pipeline_job_get():
yield mock_get_pipeline_job
+@pytest.fixture
+def mock_pipeline_job_get_eval():
+ with mock.patch.object(
+ pipeline_service_client.PipelineServiceClient, "get_pipeline_job"
+ ) as mock_get_pipeline_job:
+ mock_get_pipeline_job.side_effect = [
+ make_eval_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING
+ ),
+ make_eval_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ ),
+ make_eval_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ ),
+ make_eval_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ ),
+ make_eval_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ ),
+ make_eval_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ ),
+ make_eval_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ ),
+ make_eval_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ ),
+ make_eval_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ ),
+ ]
+
+ yield mock_get_pipeline_job
+
+
+@pytest.fixture
+def mock_pipeline_job_get_eval_classification():
+ with mock.patch.object(
+ pipeline_service_client.PipelineServiceClient, "get_pipeline_job"
+ ) as mock_get_pipeline_job:
+ mock_get_pipeline_job.side_effect = [
+ make_eval_classification_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING
+ ),
+ make_eval_classification_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ ),
+ make_eval_classification_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ ),
+ make_eval_classification_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ ),
+ make_eval_classification_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ ),
+ make_eval_classification_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ ),
+ make_eval_classification_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ ),
+ make_eval_classification_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ ),
+ make_eval_classification_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ ),
+ ]
+
+ yield mock_get_pipeline_job
+
+
@pytest.fixture
def mock_load_yaml_and_json(job_spec):
with mock.patch.object(
@@ -457,6 +996,32 @@ def mock_request_urlopen(request: str) -> Tuple[str, mock.MagicMock]:
yield request.param, mock_urlopen
+@pytest.fixture
+def mock_request_urlopen_eval(request: str) -> Tuple[str, mock.MagicMock]:
+ data = _TEST_EVAL_PIPELINE_SPEC
+ with mock.patch.object(urllib_request, "urlopen") as mock_urlopen:
+ mock_read_response = mock.MagicMock()
+ mock_decode_response = mock.MagicMock()
+ mock_decode_response.return_value = json.dumps(data)
+ mock_read_response.return_value.decode = mock_decode_response
+ mock_urlopen.return_value.read = mock_read_response
+ yield request.param, mock_urlopen
+
+
+@pytest.fixture
+def mock_request_urlopen_eval_classification(
+ request: str,
+) -> Tuple[str, mock.MagicMock]:
+ data = _TEST_EVAL_CLASSIFICATION_PIPELINE_SPEC
+ with mock.patch.object(urllib_request, "urlopen") as mock_urlopen:
+ mock_read_response = mock.MagicMock()
+ mock_decode_response = mock.MagicMock()
+ mock_decode_response.return_value = json.dumps(data)
+ mock_read_response.return_value.decode = mock_decode_response
+ mock_urlopen.return_value.read = mock_read_response
+ yield request.param, mock_urlopen
+
+
@pytest.fixture
def get_endpoint_mock():
with mock.patch.object(
@@ -474,13 +1039,13 @@ def mock_get_tuned_model(get_endpoint_mock):
with mock.patch.object(
_language_models._TunableModelMixin, "get_tuned_model"
) as mock_text_generation_model:
- mock_text_generation_model._model_id = (
+ mock_text_generation_model.return_value._model_id = (
test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
)
- mock_text_generation_model._endpoint_name = (
+ mock_text_generation_model.return_value._endpoint_name = (
test_constants.EndpointConstants._TEST_ENDPOINT_NAME
)
- mock_text_generation_model._endpoint = get_endpoint_mock
+ mock_text_generation_model.return_value._endpoint = get_endpoint_mock
yield mock_text_generation_model
@@ -523,6 +1088,48 @@ def get_endpoint_with_models_mock():
yield get_endpoint_models_mock
+# Model Evaluation fixtures
+@pytest.fixture
+def mock_model_evaluate():
+ with mock.patch.object(
+ _evaluatable_language_models._EvaluatableLanguageModel, "evaluate"
+ ) as mock_model_evaluate:
+ mock_model_evaluate.return_value = _TEST_TEXT_GENERATION_METRICS
+ yield mock_model_evaluate
+
+
+@pytest.fixture
+def mock_successfully_completed_eval_job():
+ with mock.patch.object(
+ pipeline_service_client.PipelineServiceClient, "get_pipeline_job"
+ ) as mock_get_model_eval_job:
+ mock_get_model_eval_job.return_value = make_eval_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ )
+ yield mock_get_model_eval_job
+
+
+@pytest.fixture
+def mock_successfully_completed_eval_classification_job():
+ with mock.patch.object(
+ pipeline_service_client.PipelineServiceClient, "get_pipeline_job"
+ ) as mock_get_model_eval_job:
+ mock_get_model_eval_job.return_value = make_eval_classification_pipeline_job(
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ )
+ yield mock_get_model_eval_job
+
+
+@pytest.fixture
+def mock_storage_blob_upload_from_filename():
+ with mock.patch(
+ "google.cloud.storage.Blob.upload_from_filename"
+ ) as mock_blob_upload_from_filename, mock.patch(
+ "google.cloud.storage.Bucket.exists", return_value=True
+ ):
+ yield mock_blob_upload_from_filename
+
+
@pytest.mark.usefixtures("google_auth_mock")
class TestLanguageModels:
"""Unit tests for the language models."""
@@ -530,6 +1137,10 @@ class TestLanguageModels:
def setup_method(self):
reload(initializer)
reload(aiplatform)
+ aiplatform.init(
+ project=_TEST_PROJECT,
+ location=_TEST_LOCATION,
+ )
def teardown_method(self):
initializer.global_pool.shutdown(wait=True)
@@ -577,6 +1188,10 @@ def test_text_generation(self):
)
assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]
+ assert (
+ response.raw_prediction_response.predictions[0]
+ == _TEST_TEXT_GENERATION_PREDICTION
+ )
assert (
response.safety_attributes["Violent"]
== _TEST_TEXT_GENERATION_PREDICTION["safetyAttributes"]["scores"][0]
@@ -622,6 +1237,7 @@ def test_text_generation_ga(self):
temperature=0,
top_p=1,
top_k=5,
+ stop_sequences=["\n"],
)
prediction_parameters = mock_predict.call_args[1]["parameters"]
@@ -629,6 +1245,7 @@ def test_text_generation_ga(self):
assert prediction_parameters["temperature"] == 0
assert prediction_parameters["topP"] == 1
assert prediction_parameters["topK"] == 5
+ assert prediction_parameters["stopSequences"] == ["\n"]
assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]
# Validating that unspecified parameters are not passed to the model
@@ -651,6 +1268,40 @@ def test_text_generation_ga(self):
assert "topP" not in prediction_parameters
assert "topK" not in prediction_parameters
+ def test_text_generation_model_predict_streaming(self):
+ """Tests the TextGenerationModel.predict_streaming method."""
+ with mock.patch.object(
+ target=model_garden_service_client.ModelGardenServiceClient,
+ attribute="get_publisher_model",
+ return_value=gca_publisher_model.PublisherModel(
+ _TEXT_BISON_PUBLISHER_MODEL_DICT
+ ),
+ ):
+ model = language_models.TextGenerationModel.from_pretrained(
+ "text-bison@001"
+ )
+
+ response_generator = (
+ gca_prediction_service.StreamingPredictResponse(
+ outputs=[_streaming_prediction.value_to_tensor(response_dict)]
+ )
+ for response_dict in _TEST_TEXT_GENERATION_PREDICTION_STREAMING
+ )
+
+ with mock.patch.object(
+ target=prediction_service_client.PredictionServiceClient,
+ attribute="server_streaming_predict",
+ return_value=response_generator,
+ ):
+ for response in model.predict_streaming(
+ "Count to 50",
+ max_output_tokens=1000,
+ temperature=0,
+ top_p=1,
+ top_k=5,
+ ):
+ assert len(response.text) > 10
+
@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_JOB],
@@ -689,23 +1340,130 @@ def test_tune_text_generation_model(
"text-bison@001"
)
- model.tune_model(
+ tuning_job_location = "europe-west4"
+ evaluation_data_uri = "gs://bucket/eval.jsonl"
+ evaluation_interval = 37
+ enable_early_stopping = True
+ tensorboard_name = f"projects/{_TEST_PROJECT}/locations/{tuning_job_location}/tensorboards/123"
+
+ tuning_job = model.tune_model(
training_data=_TEST_TEXT_BISON_TRAINING_DF,
- tuning_job_location="europe-west4",
+ tuning_job_location=tuning_job_location,
tuned_model_location="us-central1",
learning_rate=0.1,
+ learning_rate_multiplier=2.0,
+ train_steps=10,
+ tuning_evaluation_spec=preview_language_models.TuningEvaluationSpec(
+ evaluation_data=evaluation_data_uri,
+ evaluation_interval=evaluation_interval,
+ enable_early_stopping=enable_early_stopping,
+ tensorboard=tensorboard_name,
+ ),
)
call_kwargs = mock_pipeline_service_create.call_args[1]
pipeline_arguments = call_kwargs[
"pipeline_job"
].runtime_config.parameter_values
assert pipeline_arguments["learning_rate"] == 0.1
+ assert pipeline_arguments["learning_rate_multiplier"] == 2.0
+ assert pipeline_arguments["train_steps"] == 10
+ assert pipeline_arguments["evaluation_data_uri"] == evaluation_data_uri
+ assert pipeline_arguments["evaluation_interval"] == evaluation_interval
+ assert pipeline_arguments["enable_early_stopping"] == enable_early_stopping
+ assert pipeline_arguments["tensorboard_resource_id"] == tensorboard_name
+ assert pipeline_arguments["large_model_reference"] == "text-bison@001"
+ assert (
+ call_kwargs["pipeline_job"].encryption_spec.kms_key_name
+ == _TEST_ENCRYPTION_KEY_NAME
+ )
+
+ # Testing the tuned model
+ tuned_model = tuning_job.get_tuned_model()
+ assert (
+ tuned_model._endpoint_name
+ == test_constants.EndpointConstants._TEST_ENDPOINT_NAME
+ )
+
+ @pytest.mark.parametrize(
+ "job_spec",
+ [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_JOB],
+ )
+ @pytest.mark.parametrize(
+ "mock_request_urlopen",
+ ["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
+ indirect=True,
+ )
+ def test_tune_text_generation_model_ga(
+ self,
+ mock_pipeline_service_create,
+ mock_pipeline_job_get,
+ mock_pipeline_bucket_exists,
+ job_spec,
+ mock_load_yaml_and_json,
+ mock_gcs_from_string,
+ mock_gcs_upload,
+ mock_request_urlopen,
+ mock_get_tuned_model,
+ ):
+ """Tests tuning the text generation model."""
+ aiplatform.init(
+ project=_TEST_PROJECT,
+ location=_TEST_LOCATION,
+ encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
+ )
+ with mock.patch.object(
+ target=model_garden_service_client.ModelGardenServiceClient,
+ attribute="get_publisher_model",
+ return_value=gca_publisher_model.PublisherModel(
+ _TEXT_BISON_PUBLISHER_MODEL_DICT
+ ),
+ ):
+ model = language_models.TextGenerationModel.from_pretrained(
+ "text-bison@001"
+ )
+
+ tuning_job_location = "europe-west4"
+ evaluation_data_uri = "gs://bucket/eval.jsonl"
+ evaluation_interval = 37
+ enable_early_stopping = True
+ tensorboard_name = f"projects/{_TEST_PROJECT}/locations/{tuning_job_location}/tensorboards/123"
+
+ tuning_job = model.tune_model(
+ training_data=_TEST_TEXT_BISON_TRAINING_DF,
+ tuning_job_location=tuning_job_location,
+ tuned_model_location="us-central1",
+ learning_rate_multiplier=2.0,
+ train_steps=10,
+ tuning_evaluation_spec=preview_language_models.TuningEvaluationSpec(
+ evaluation_data=evaluation_data_uri,
+ evaluation_interval=evaluation_interval,
+ enable_early_stopping=enable_early_stopping,
+ tensorboard=tensorboard_name,
+ ),
+ )
+ call_kwargs = mock_pipeline_service_create.call_args[1]
+ pipeline_arguments = call_kwargs[
+ "pipeline_job"
+ ].runtime_config.parameter_values
+ assert pipeline_arguments["learning_rate_multiplier"] == 2.0
+ assert pipeline_arguments["train_steps"] == 10
+ assert pipeline_arguments["evaluation_data_uri"] == evaluation_data_uri
+ assert pipeline_arguments["evaluation_interval"] == evaluation_interval
+ assert pipeline_arguments["enable_early_stopping"] == enable_early_stopping
+ assert pipeline_arguments["tensorboard_resource_id"] == tensorboard_name
assert pipeline_arguments["large_model_reference"] == "text-bison@001"
assert (
call_kwargs["pipeline_job"].encryption_spec.kms_key_name
== _TEST_ENCRYPTION_KEY_NAME
)
+ # Testing the tuned model
+ tuned_model = tuning_job.get_tuned_model()
+ assert (
+ tuned_model._endpoint_name
+ == test_constants.EndpointConstants._TEST_ENDPOINT_NAME
+ )
+
@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON],
@@ -738,16 +1496,72 @@ def test_tune_chat_model(
):
model = preview_language_models.ChatModel.from_pretrained("chat-bison@001")
- model.tune_model(
+ default_context = "Default context"
+ tuning_job = model.tune_model(
training_data=_TEST_TEXT_BISON_TRAINING_DF,
tuning_job_location="europe-west4",
tuned_model_location="us-central1",
+ default_context=default_context,
)
call_kwargs = mock_pipeline_service_create.call_args[1]
pipeline_arguments = call_kwargs[
"pipeline_job"
].runtime_config.parameter_values
assert pipeline_arguments["large_model_reference"] == "chat-bison@001"
+ assert pipeline_arguments["default_context"] == default_context
+
+ # Testing the tuned model
+ tuned_model = tuning_job.get_tuned_model()
+ assert (
+ tuned_model._endpoint_name
+ == test_constants.EndpointConstants._TEST_ENDPOINT_NAME
+ )
+
+ @pytest.mark.parametrize(
+ "job_spec",
+ [_TEST_PIPELINE_SPEC_JSON],
+ )
+ @pytest.mark.parametrize(
+ "mock_request_urlopen",
+ ["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
+ indirect=True,
+ )
+ def test_tune_code_generation_model(
+ self,
+ mock_pipeline_service_create,
+ mock_pipeline_job_get,
+ mock_pipeline_bucket_exists,
+ job_spec,
+ mock_load_yaml_and_json,
+ mock_gcs_from_string,
+ mock_gcs_upload,
+ mock_request_urlopen,
+ mock_get_tuned_model,
+ ):
+ """Tests tuning a code generation model."""
+ aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+ with mock.patch.object(
+ target=model_garden_service_client.ModelGardenServiceClient,
+ attribute="get_publisher_model",
+ return_value=gca_publisher_model.PublisherModel(
+ _CODE_GENERATION_BISON_PUBLISHER_MODEL_DICT
+ ),
+ ):
+ model = preview_language_models.CodeGenerationModel.from_pretrained(
+ "code-bison@001"
+ )
+ # The tune_model call needs to be inside the PublisherModel mock
+ # since it gets a new PublisherModel when tuning completes.
+ model.tune_model(
+ training_data=_TEST_TEXT_BISON_TRAINING_DF,
+ tuning_job_location="europe-west4",
+ tuned_model_location="us-central1",
+ )
+ call_kwargs = mock_pipeline_service_create.call_args[1]
+ pipeline_arguments = call_kwargs[
+ "pipeline_job"
+ ].runtime_config.parameter_values
+ assert pipeline_arguments["large_model_reference"] == "code-bison@001"
@pytest.mark.parametrize(
"job_spec",
@@ -1066,16 +1880,19 @@ def test_chat_ga(self):
chat_max_output_tokens = 100
chat_top_k = 1
chat_top_p = 0.1
+ stop_sequences = ["\n"]
message_temperature = 0.2
message_max_output_tokens = 200
message_top_k = 2
message_top_p = 0.2
+ message_stop_sequences = ["# %%"]
chat2 = model.start_chat(
temperature=chat_temperature,
max_output_tokens=chat_max_output_tokens,
top_k=chat_top_k,
top_p=chat_top_p,
+ stop_sequences=stop_sequences,
)
gca_predict_response3 = gca_prediction_service.PredictResponse()
@@ -1092,6 +1909,7 @@ def test_chat_ga(self):
assert prediction_parameters["maxDecodeSteps"] == chat_max_output_tokens
assert prediction_parameters["topK"] == chat_top_k
assert prediction_parameters["topP"] == chat_top_p
+ assert prediction_parameters["stopSequences"] == stop_sequences
chat2.send_message(
"Are my favorite movies based on a book series?",
@@ -1099,14 +1917,96 @@ def test_chat_ga(self):
max_output_tokens=message_max_output_tokens,
top_k=message_top_k,
top_p=message_top_p,
+ stop_sequences=message_stop_sequences,
)
prediction_parameters = mock_predict3.call_args[1]["parameters"]
assert prediction_parameters["temperature"] == message_temperature
assert prediction_parameters["maxDecodeSteps"] == message_max_output_tokens
assert prediction_parameters["topK"] == message_top_k
assert prediction_parameters["topP"] == message_top_p
+ assert prediction_parameters["stopSequences"] == message_stop_sequences
- def test_code_chat(self):
+ def test_chat_model_send_message_streaming(self):
+ """Tests the chat generation model."""
+ with mock.patch.object(
+ target=model_garden_service_client.ModelGardenServiceClient,
+ attribute="get_publisher_model",
+ return_value=gca_publisher_model.PublisherModel(
+ _CHAT_BISON_PUBLISHER_MODEL_DICT
+ ),
+ ):
+ model = language_models.ChatModel.from_pretrained("chat-bison@001")
+
+ chat = model.start_chat(
+ context="""
+ My name is Ned.
+ You are my personal assistant.
+ My favorite movies are Lord of the Rings and Hobbit.
+ """,
+ examples=[
+ language_models.InputOutputTextPair(
+ input_text="Who do you work for?",
+ output_text="I work for Ned.",
+ ),
+ language_models.InputOutputTextPair(
+ input_text="What do I like?",
+ output_text="Ned likes watching movies.",
+ ),
+ ],
+ message_history=[
+ language_models.ChatMessage(
+ author=preview_language_models.ChatSession.USER_AUTHOR,
+ content="Question 1?",
+ ),
+ language_models.ChatMessage(
+ author=preview_language_models.ChatSession.MODEL_AUTHOR,
+ content="Answer 1.",
+ ),
+ ],
+ temperature=0.0,
+ )
+
+ # Using list instead of a generator so that it can be reused.
+ response_generator = [
+ gca_prediction_service.StreamingPredictResponse(
+ outputs=[_streaming_prediction.value_to_tensor(response_dict)]
+ )
+ for response_dict in _TEST_CHAT_PREDICTION_STREAMING
+ ]
+
+ message_temperature = 0.2
+ message_max_output_tokens = 200
+ message_top_k = 2
+ message_top_p = 0.2
+
+ with mock.patch.object(
+ target=prediction_service_client.PredictionServiceClient,
+ attribute="server_streaming_predict",
+ return_value=response_generator,
+ ):
+ message_text1 = "Are my favorite movies based on a book series?"
+
+ for idx, response in enumerate(
+ chat.send_message_streaming(
+ message=message_text1,
+ max_output_tokens=message_max_output_tokens,
+ temperature=message_temperature,
+ top_k=message_top_k,
+ top_p=message_top_p,
+ )
+ ):
+ assert len(response.text) > 10
+ # New messages are not added until the response is fully read
+ if idx + 1 < len(response_generator):
+ assert len(chat.message_history) == 2
+
+ # New messages are only added after the response is fully read
+ assert len(chat.message_history) == 4
+ assert chat.message_history[2].author == chat.USER_AUTHOR
+ assert chat.message_history[2].content == message_text1
+ assert chat.message_history[3].author == chat.MODEL_AUTHOR
+
+ def test_code_chat(self):
"""Tests the code chat model."""
aiplatform.init(
project=_TEST_PROJECT,
@@ -1202,6 +2102,51 @@ def test_code_chat(self):
assert prediction_parameters["temperature"] == message_temperature
assert prediction_parameters["maxDecodeSteps"] == message_max_output_tokens
+ def test_code_chat_model_send_message_streaming(self):
+ """Tests the chat generation model."""
+ aiplatform.init(
+ project=_TEST_PROJECT,
+ location=_TEST_LOCATION,
+ )
+ with mock.patch.object(
+ target=model_garden_service_client.ModelGardenServiceClient,
+ attribute="get_publisher_model",
+ return_value=gca_publisher_model.PublisherModel(
+ _CODECHAT_BISON_PUBLISHER_MODEL_DICT
+ ),
+ ):
+ model = language_models.CodeChatModel.from_pretrained("codechat-bison@001")
+
+ chat = model.start_chat(temperature=0.0)
+
+ # Using list instead of a generator so that it can be reused.
+ response_generator = [
+ gca_prediction_service.StreamingPredictResponse(
+ outputs=[_streaming_prediction.value_to_tensor(response_dict)]
+ )
+ for response_dict in _TEST_CHAT_PREDICTION_STREAMING
+ ]
+
+ with mock.patch.object(
+ target=prediction_service_client.PredictionServiceClient,
+ attribute="server_streaming_predict",
+ return_value=response_generator,
+ ):
+ message_text1 = (
+ "Please help write a function to calculate the max of two numbers"
+ )
+ # New messages are not added until the response is fully read
+ assert not chat.message_history
+ for response in chat.send_message_streaming(message_text1):
+ assert len(response.text) > 10
+ # New messages are only added after the response is fully read
+ assert chat.message_history
+
+ assert len(chat.message_history) == 2
+ assert chat.message_history[0].author == chat.USER_AUTHOR
+ assert chat.message_history[0].content == message_text1
+ assert chat.message_history[1].author == chat.MODEL_AUTHOR
+
def test_code_generation(self):
"""Tests code generation with the code generation model."""
aiplatform.init(
@@ -1245,6 +2190,7 @@ def test_code_generation(self):
default_max_output_tokens = (
language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
)
+ stop_sequences = ["\n"]
with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
@@ -1255,10 +2201,12 @@ def test_code_generation(self):
prefix="Write a function that checks if a year is a leap year.",
max_output_tokens=predict_max_output_tokens,
temperature=predict_temperature,
+ stop_sequences=stop_sequences,
)
prediction_parameters = mock_predict.call_args[1]["parameters"]
assert prediction_parameters["temperature"] == predict_temperature
assert prediction_parameters["maxOutputTokens"] == predict_max_output_tokens
+ assert prediction_parameters["stopSequences"] == stop_sequences
model.predict(
prefix="Write a function that checks if a year is a leap year.",
@@ -1332,6 +2280,39 @@ def test_code_completion(self):
assert "temperature" not in prediction_parameters
assert prediction_parameters["maxOutputTokens"] == default_max_output_tokens
+ def test_code_generation_model_predict_streaming(self):
+ """Tests the TextGenerationModel.predict_streaming method."""
+ with mock.patch.object(
+ target=model_garden_service_client.ModelGardenServiceClient,
+ attribute="get_publisher_model",
+ return_value=gca_publisher_model.PublisherModel(
+ _CODE_GENERATION_BISON_PUBLISHER_MODEL_DICT
+ ),
+ ):
+ model = language_models.CodeGenerationModel.from_pretrained(
+ "code-bison@001"
+ )
+
+ response_generator = (
+ gca_prediction_service.StreamingPredictResponse(
+ outputs=[_streaming_prediction.value_to_tensor(response_dict)]
+ )
+ for response_dict in _TEST_TEXT_GENERATION_PREDICTION_STREAMING
+ )
+
+ with mock.patch.object(
+ target=prediction_service_client.PredictionServiceClient,
+ attribute="server_streaming_predict",
+ return_value=response_generator,
+ ):
+ for response in model.predict_streaming(
+ prefix="def reverse_string(s):",
+ suffix=" return s",
+ max_output_tokens=1000,
+ temperature=0,
+ ):
+ assert len(response.text) > 10
+
def test_text_embedding(self):
"""Tests the text embedding model."""
aiplatform.init(
@@ -1356,18 +2337,57 @@ def test_text_embedding(self):
gca_predict_response = gca_prediction_service.PredictResponse()
gca_predict_response.predictions.append(_TEST_TEXT_EMBEDDING_PREDICTION)
+ gca_predict_response.predictions.append(_TEST_TEXT_EMBEDDING_PREDICTION)
+ expected_embedding = _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]
with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
attribute="predict",
return_value=gca_predict_response,
- ):
- embeddings = model.get_embeddings(["What is life?"])
+ ) as mock_predict:
+ embeddings = model.get_embeddings(
+ [
+ "What is life?",
+ language_models.TextEmbeddingInput(
+ text="Foo",
+ task_type="RETRIEVAL_DOCUMENT",
+ title="Bar",
+ ),
+ language_models.TextEmbeddingInput(
+ text="Baz",
+ task_type="CLASSIFICATION",
+ ),
+ ],
+ auto_truncate=False,
+ )
+ prediction_instances = mock_predict.call_args[1]["instances"]
+ assert prediction_instances == [
+ {"content": "What is life?"},
+ {
+ "content": "Foo",
+ "taskType": "RETRIEVAL_DOCUMENT",
+ "title": "Bar",
+ },
+ {
+ "content": "Baz",
+ "taskType": "CLASSIFICATION",
+ },
+ ]
+ prediction_parameters = mock_predict.call_args[1]["parameters"]
+ assert not prediction_parameters["autoTruncate"]
assert embeddings
for embedding in embeddings:
vector = embedding.values
assert len(vector) == _TEXT_EMBEDDING_VECTOR_LENGTH
- assert vector == _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]["values"]
+ assert vector == expected_embedding["values"]
+ assert (
+ embedding.statistics.token_count
+ == expected_embedding["statistics"]["token_count"]
+ )
+ assert (
+ embedding.statistics.truncated
+ == expected_embedding["statistics"]["truncated"]
+ )
def test_text_embedding_ga(self):
"""Tests the text embedding model."""
@@ -1473,3 +2493,330 @@ def test_batch_prediction_for_text_embedding(self):
gcs_destination_prefix="gs://test-bucket/results/",
model_parameters={},
)
+
+
+# TODO (b/285946649): add more test coverage before public preview release
+@pytest.mark.usefixtures("google_auth_mock")
+class TestLanguageModelEvaluation:
+ @pytest.mark.usefixtures(
+ "get_model_with_tuned_version_label_mock",
+ "get_endpoint_with_models_mock",
+ )
+ @pytest.mark.parametrize(
+ "job_spec",
+ [_TEST_EVAL_PIPELINE_SPEC_JSON, _TEST_EVAL_PIPELINE_JOB],
+ )
+ @pytest.mark.parametrize(
+ "mock_request_urlopen_eval",
+ ["https://us-kfp.pkg.dev/proj/repo/pack/latest"],
+ indirect=True,
+ )
+ def test_model_evaluation_text_generation_task_with_gcs_input(
+ self,
+ job_spec,
+ mock_pipeline_service_create_eval,
+ mock_pipeline_job_get_eval,
+ mock_successfully_completed_eval_job,
+ mock_pipeline_bucket_exists,
+ mock_load_yaml_and_json,
+ mock_request_urlopen_eval,
+ ):
+
+ aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+ with mock.patch.object(
+ target=model_garden_service_client.ModelGardenServiceClient,
+ attribute="get_publisher_model",
+ return_value=gca_publisher_model.PublisherModel(
+ _TEXT_BISON_PUBLISHER_MODEL_DICT
+ ),
+ ):
+
+ my_model = preview_language_models.TextGenerationModel.get_tuned_model(
+ test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
+ )
+
+ eval_metrics = my_model.evaluate(
+ task_spec=preview_language_models.EvaluationTextGenerationSpec(
+ ground_truth_data="gs://my-bucket/ground-truth.jsonl",
+ ),
+ )
+
+ assert isinstance(eval_metrics, preview_language_models.EvaluationMetric)
+ assert eval_metrics.bleu == _TEST_TEXT_GENERATION_METRICS["bleu"]
+
+ @pytest.mark.usefixtures(
+ "get_model_with_tuned_version_label_mock",
+ "get_endpoint_with_models_mock",
+ )
+ @pytest.mark.parametrize(
+ "job_spec",
+ [_TEST_EVAL_PIPELINE_SPEC_JSON, _TEST_EVAL_PIPELINE_JOB],
+ )
+ def test_populate_eval_template_params(
+ self,
+ job_spec,
+ mock_pipeline_service_create,
+ mock_model_evaluate,
+ mock_pipeline_job_get,
+ mock_successfully_completed_eval_job,
+ mock_pipeline_bucket_exists,
+ mock_load_yaml_and_json,
+ ):
+
+ aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+ with mock.patch.object(
+ target=model_garden_service_client.ModelGardenServiceClient,
+ attribute="get_publisher_model",
+ return_value=gca_publisher_model.PublisherModel(
+ _TEXT_BISON_PUBLISHER_MODEL_DICT
+ ),
+ ):
+
+ my_model = preview_language_models.TextGenerationModel.get_tuned_model(
+ test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
+ )
+
+ task_spec = preview_language_models.EvaluationTextGenerationSpec(
+ ground_truth_data="gs://my-bucket/ground-truth.jsonl",
+ )
+
+ formatted_template_params = (
+ _evaluatable_language_models._populate_eval_template_params(
+ task_spec=task_spec, model_name=my_model._model_resource_name
+ )
+ )
+
+ assert (
+ "batch_predict_gcs_destination_output_uri" in formatted_template_params
+ )
+ assert "model_name" in formatted_template_params
+ assert "evaluation_task" in formatted_template_params
+
+ # This should only be in the classification task pipeline template
+ assert "evaluation_class_labels" not in formatted_template_params
+ assert "target_column_name" not in formatted_template_params
+
+ @pytest.mark.usefixtures(
+ "get_model_with_tuned_version_label_mock",
+ "get_endpoint_with_models_mock",
+ )
+ @pytest.mark.parametrize(
+ "job_spec",
+ [_TEST_EVAL_PIPELINE_SPEC_JSON, _TEST_EVAL_PIPELINE_JOB],
+ )
+ def test_populate_template_params_for_classification_task(
+ self,
+ job_spec,
+ mock_pipeline_service_create,
+ mock_model_evaluate,
+ mock_pipeline_job_get,
+ mock_successfully_completed_eval_job,
+ mock_pipeline_bucket_exists,
+ mock_load_yaml_and_json,
+ ):
+
+ aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+ with mock.patch.object(
+ target=model_garden_service_client.ModelGardenServiceClient,
+ attribute="get_publisher_model",
+ return_value=gca_publisher_model.PublisherModel(
+ _TEXT_BISON_PUBLISHER_MODEL_DICT
+ ),
+ ):
+
+ my_model = preview_language_models.TextGenerationModel.get_tuned_model(
+ test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
+ )
+
+ task_spec = preview_language_models.EvaluationTextClassificationSpec(
+ ground_truth_data="gs://my-bucket/ground-truth.jsonl",
+ target_column_name="test_targ_name",
+ class_names=["test_class_name_1", "test_class_name_2"],
+ )
+
+ formatted_template_params = (
+ _evaluatable_language_models._populate_eval_template_params(
+ task_spec=task_spec, model_name=my_model._model_resource_name
+ )
+ )
+
+ assert "evaluation_class_labels" in formatted_template_params
+ assert "target_field_name" in formatted_template_params
+
+ @pytest.mark.usefixtures(
+ "get_model_with_tuned_version_label_mock",
+ "get_endpoint_with_models_mock",
+ "mock_storage_blob_upload_from_filename",
+ )
+ @pytest.mark.parametrize(
+ "job_spec",
+ [_TEST_EVAL_PIPELINE_SPEC_JSON, _TEST_EVAL_PIPELINE_JOB],
+ )
+ def test_populate_template_params_with_dataframe_input(
+ self,
+ job_spec,
+ mock_pipeline_service_create,
+ mock_pipeline_job_get,
+ mock_successfully_completed_eval_job,
+ mock_pipeline_bucket_exists,
+ mock_load_yaml_and_json,
+ ):
+
+ aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+ with mock.patch.object(
+ target=model_garden_service_client.ModelGardenServiceClient,
+ attribute="get_publisher_model",
+ return_value=gca_publisher_model.PublisherModel(
+ _TEXT_BISON_PUBLISHER_MODEL_DICT
+ ),
+ ):
+
+ my_model = preview_language_models.TextGenerationModel.get_tuned_model(
+ test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
+ )
+
+ task_spec = preview_language_models.EvaluationTextGenerationSpec(
+ ground_truth_data=_TEST_EVAL_DATA_DF,
+ )
+
+ formatted_template_params = (
+ _evaluatable_language_models._populate_eval_template_params(
+ task_spec=task_spec, model_name=my_model._model_resource_name
+ )
+ )
+
+ # The utility method should not modify task_spec
+ assert isinstance(task_spec.ground_truth_data, pd.DataFrame)
+
+ assert (
+ "batch_predict_gcs_destination_output_uri" in formatted_template_params
+ )
+ assert "model_name" in formatted_template_params
+ assert "evaluation_task" in formatted_template_params
+
+ # This should only be in the classification task pipeline template
+ assert "evaluation_class_labels" not in formatted_template_params
+ assert "target_column_name" not in formatted_template_params
+
+ def test_evaluate_raises_on_ga_language_model(
+ self,
+ ):
+
+ aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+ with mock.patch.object(
+ target=model_garden_service_client.ModelGardenServiceClient,
+ attribute="get_publisher_model",
+ return_value=gca_publisher_model.PublisherModel(
+ _TEXT_BISON_PUBLISHER_MODEL_DICT
+ ),
+ ):
+ model = language_models.TextGenerationModel.from_pretrained(
+ "text-bison@001"
+ )
+
+ with pytest.raises(AttributeError):
+ model.evaluate()
+
+ @pytest.mark.usefixtures(
+ "get_endpoint_with_models_mock",
+ )
+ @pytest.mark.parametrize(
+ "job_spec",
+ [_TEST_EVAL_PIPELINE_SPEC_JSON, _TEST_EVAL_PIPELINE_JOB],
+ )
+ @pytest.mark.parametrize(
+ "mock_request_urlopen_eval",
+ ["https://us-kfp.pkg.dev/proj/repo/pack/latest"],
+ indirect=True,
+ )
+ def test_model_evaluation_text_generation_task_on_base_model(
+ self,
+ job_spec,
+ mock_pipeline_service_create_eval,
+ mock_pipeline_job_get_eval,
+ mock_successfully_completed_eval_job,
+ mock_pipeline_bucket_exists,
+ mock_load_yaml_and_json,
+ mock_request_urlopen_eval,
+ ):
+
+ aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+ with mock.patch.object(
+ target=model_garden_service_client.ModelGardenServiceClient,
+ attribute="get_publisher_model",
+ return_value=gca_publisher_model.PublisherModel(
+ _TEXT_BISON_PUBLISHER_MODEL_DICT
+ ),
+ ):
+
+ my_model = preview_language_models.TextGenerationModel.from_pretrained(
+ "text-bison@001"
+ )
+
+ eval_metrics = my_model.evaluate(
+ task_spec=preview_language_models.EvaluationTextGenerationSpec(
+ ground_truth_data="gs://my-bucket/ground-truth.jsonl",
+ ),
+ )
+
+ assert isinstance(eval_metrics, preview_language_models.EvaluationMetric)
+
+ @pytest.mark.usefixtures(
+ "get_endpoint_with_models_mock",
+ )
+ @pytest.mark.parametrize(
+ "job_spec",
+ [
+ _TEST_EVAL_CLASSIFICATION_PIPELINE_SPEC_JSON,
+ _TEST_EVAL_CLASSIFICATION_PIPELINE_JOB,
+ ],
+ )
+ @pytest.mark.parametrize(
+ "mock_request_urlopen_eval_classification",
+ ["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
+ indirect=True,
+ )
+ def test_model_evaluation_text_classification_base_model_only_summary_metrics(
+ self,
+ job_spec,
+ mock_pipeline_service_create_eval_classification,
+ mock_pipeline_job_get_eval_classification,
+ mock_successfully_completed_eval_classification_job,
+ mock_pipeline_bucket_exists,
+ mock_load_yaml_and_json,
+ mock_request_urlopen_eval_classification,
+ ):
+
+ aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+ with mock.patch.object(
+ target=model_garden_service_client.ModelGardenServiceClient,
+ attribute="get_publisher_model",
+ return_value=gca_publisher_model.PublisherModel(
+ _TEXT_BISON_PUBLISHER_MODEL_DICT
+ ),
+ ):
+ my_model = preview_language_models.TextGenerationModel.from_pretrained(
+ "text-bison@001"
+ )
+
+ eval_metrics = my_model.evaluate(
+ task_spec=preview_language_models.EvaluationTextClassificationSpec(
+ ground_truth_data="gs://my-bucket/ground-truth.jsonl",
+ target_column_name="test_targ_name",
+ class_names=["test_class_name_1", "test_class_name_2"],
+ )
+ )
+
+ assert isinstance(
+ eval_metrics,
+ preview_language_models.EvaluationClassificationMetric,
+ )
+ assert eval_metrics.confidenceMetrics is None
+ assert eval_metrics.auPrc == _TEST_TEXT_CLASSIFICATION_METRICS["auPrc"]
diff --git a/tests/unit/aiplatform/test_pipeline_job_schedules.py b/tests/unit/aiplatform/test_pipeline_job_schedules.py
index 5895efe148..61d19b41fe 100644
--- a/tests/unit/aiplatform/test_pipeline_job_schedules.py
+++ b/tests/unit/aiplatform/test_pipeline_job_schedules.py
@@ -29,21 +29,25 @@
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import pipeline_jobs
+from google.cloud.aiplatform.constants import pipeline as pipeline_constants
from google.cloud.aiplatform.compat.services import (
pipeline_service_client,
- schedule_service_client_v1beta1 as schedule_service_client,
+ schedule_service_client,
)
from google.cloud.aiplatform.compat.types import (
- context_v1beta1 as gca_context,
- pipeline_job_v1beta1 as gca_pipeline_job,
- pipeline_state_v1beta1 as gca_pipeline_state,
- schedule_v1beta1 as gca_schedule,
+ context as gca_context,
+ pipeline_job as gca_pipeline_job,
+ pipeline_state as gca_pipeline_state,
+ schedule as gca_schedule,
+)
+from google.cloud.aiplatform import (
+ pipeline_job_schedules,
)
from google.cloud.aiplatform.preview.pipelinejob import (
pipeline_jobs as preview_pipeline_jobs,
)
from google.cloud.aiplatform.preview.pipelinejobschedule import (
- pipeline_job_schedules,
+ pipeline_job_schedules as preview_pipeline_job_schedules,
)
from google.cloud.aiplatform.utils import gcs_utils
import pytest
@@ -51,6 +55,7 @@
from google.protobuf import struct_pb2
from google.protobuf import json_format
+from google.protobuf import field_mask_pb2 as field_mask
_TEST_PROJECT = "test-project"
_TEST_LOCATION = "us-central1"
@@ -63,11 +68,11 @@
_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME = "sample-pipeline-job-schedule-display-name"
_TEST_PIPELINE_JOB_SCHEDULE_ID = "sample-test-schedule-20230417"
_TEST_PIPELINE_JOB_SCHEDULE_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/schedules/{_TEST_PIPELINE_JOB_SCHEDULE_ID}"
-_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION = "* * * * *"
+_TEST_PIPELINE_JOB_SCHEDULE_CRON = "* * * * *"
_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT = 1
_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT = 2
-_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION = "1 1 1 1 1"
+_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON = "1 1 1 1 1"
_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT = 5
_TEST_TEMPLATE_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}/job_spec.json"
@@ -78,6 +83,9 @@
f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_SCHEDULE_ID}"
)
+_TEST_PIPELINE_JOB_LIST_READ_MASK = field_mask.FieldMask(
+ paths=pipeline_constants._READ_MASK_FIELDS
+)
_TEST_PIPELINE_PARAMETER_VALUES_LEGACY = {"string_param": "hello"}
_TEST_PIPELINE_PARAMETER_VALUES = {
"string_param": "hello world",
@@ -229,7 +237,7 @@ def mock_schedule_service_create():
name=_TEST_PIPELINE_JOB_SCHEDULE_NAME,
state=gca_schedule.Schedule.State.COMPLETED,
create_time=_TEST_PIPELINE_CREATE_TIME,
- cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
create_pipeline_job_request=_TEST_CREATE_PIPELINE_JOB_REQUEST,
@@ -267,7 +275,7 @@ def make_schedule(state):
name=_TEST_PIPELINE_JOB_SCHEDULE_NAME,
state=state,
create_time=_TEST_PIPELINE_CREATE_TIME,
- cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
create_pipeline_job_request=_TEST_CREATE_PIPELINE_JOB_REQUEST,
@@ -381,7 +389,7 @@ def mock_schedule_service_update():
name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
state=gca_schedule.Schedule.State.COMPLETED,
create_time=_TEST_PIPELINE_CREATE_TIME,
- cron=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
create_pipeline_job_request=_TEST_CREATE_PIPELINE_JOB_REQUEST,
@@ -423,6 +431,89 @@ def setup_method(self):
def teardown_method(self):
initializer.global_pool.shutdown(wait=True)
+ @pytest.mark.parametrize(
+ "job_spec",
+ [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
+ )
+ def test_call_preview_schedule_service_create(
+ self,
+ mock_schedule_service_create,
+ mock_schedule_service_get,
+ mock_schedule_bucket_exists,
+ job_spec,
+ mock_load_yaml_and_json,
+ ):
+ """Creates a PipelineJobSchedule.
+
+ Creates PipelineJob with template stored in GCS bucket.
+ """
+ aiplatform.init(
+ project=_TEST_PROJECT,
+ staging_bucket=_TEST_GCS_BUCKET_NAME,
+ location=_TEST_LOCATION,
+ credentials=_TEST_CREDENTIALS,
+ )
+
+ job = pipeline_jobs.PipelineJob(
+ display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+ template_path=_TEST_TEMPLATE_PATH,
+ parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+ input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
+ enable_caching=True,
+ )
+
+ pipeline_job_schedule = preview_pipeline_job_schedules.PipelineJobSchedule(
+ pipeline_job=job,
+ display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
+ )
+
+ pipeline_job_schedule.create(
+ cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
+ max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
+ max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+ service_account=_TEST_SERVICE_ACCOUNT,
+ network=_TEST_NETWORK,
+ create_request_timeout=None,
+ )
+
+ expected_runtime_config_dict = {
+ "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
+ "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
+ "inputArtifacts": {"vertex_model": {"artifactId": "456"}},
+ }
+ runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
+ json_format.ParseDict(expected_runtime_config_dict, runtime_config)
+
+ job_spec = yaml.safe_load(job_spec)
+ pipeline_spec = job_spec.get("pipelineSpec") or job_spec
+
+ # Construct expected request
+ expected_gapic_pipeline_job_schedule = gca_schedule.Schedule(
+ display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
+ max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
+ max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+ create_pipeline_job_request={
+ "parent": _TEST_PARENT,
+ "pipeline_job": {
+ "runtime_config": runtime_config,
+ "pipeline_spec": dict_to_struct(pipeline_spec),
+ "service_account": _TEST_SERVICE_ACCOUNT,
+ "network": _TEST_NETWORK,
+ },
+ },
+ )
+
+ mock_schedule_service_create.assert_called_once_with(
+ parent=_TEST_PARENT,
+ schedule=expected_gapic_pipeline_job_schedule,
+ timeout=None,
+ )
+
+ assert pipeline_job_schedule._gca_resource == make_schedule(
+ gca_schedule.Schedule.State.COMPLETED
+ )
+
@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
@@ -460,7 +551,7 @@ def test_call_schedule_service_create(
)
pipeline_job_schedule.create(
- cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
service_account=_TEST_SERVICE_ACCOUNT,
@@ -482,7 +573,7 @@ def test_call_schedule_service_create(
# Construct expected request
expected_gapic_pipeline_job_schedule = gca_schedule.Schedule(
display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
- cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
create_pipeline_job_request={
@@ -544,7 +635,7 @@ def test_call_schedule_service_create_with_different_timezone(
test_pipeline_job_schedule_cron_tz_expression = "TZ=America/New_York * * * * *"
pipeline_job_schedule.create(
- cron_expression=test_pipeline_job_schedule_cron_tz_expression,
+ cron=test_pipeline_job_schedule_cron_tz_expression,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
service_account=_TEST_SERVICE_ACCOUNT,
@@ -627,7 +718,7 @@ def test_call_schedule_service_create_artifact_registry(
)
pipeline_job_schedule.create(
- cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
service_account=_TEST_SERVICE_ACCOUNT,
@@ -648,7 +739,7 @@ def test_call_schedule_service_create_artifact_registry(
# Construct expected request
expected_gapic_pipeline_job_schedule = gca_schedule.Schedule(
display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
- cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
create_pipeline_job_request={
@@ -710,7 +801,7 @@ def test_call_schedule_service_create_https(
)
pipeline_job_schedule.create(
- cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
service_account=_TEST_SERVICE_ACCOUNT,
@@ -731,7 +822,7 @@ def test_call_schedule_service_create_https(
# Construct expected request
expected_gapic_pipeline_job_schedule = gca_schedule.Schedule(
display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
- cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
create_pipeline_job_request={
@@ -792,7 +883,7 @@ def test_call_schedule_service_create_with_timeout(
)
pipeline_job_schedule.create(
- cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
service_account=_TEST_SERVICE_ACCOUNT,
@@ -813,7 +904,7 @@ def test_call_schedule_service_create_with_timeout(
# Construct expected request
expected_gapic_pipeline_job_schedule = gca_schedule.Schedule(
display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
- cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
create_pipeline_job_request={
@@ -873,7 +964,7 @@ def test_call_schedule_service_create_with_timeout_not_explicitly_set(
)
pipeline_job_schedule.create(
- cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
service_account=_TEST_SERVICE_ACCOUNT,
@@ -893,7 +984,7 @@ def test_call_schedule_service_create_with_timeout_not_explicitly_set(
# Construct expected request
expected_gapic_pipeline_job_schedule = gca_schedule.Schedule(
display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
- cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
create_pipeline_job_request={
@@ -917,7 +1008,7 @@ def test_call_schedule_service_create_with_timeout_not_explicitly_set(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
- def test_call_pipeline_job_create_schedule(
+ def test_call_preview_pipeline_job_create_schedule(
self,
mock_schedule_service_create,
mock_schedule_service_get,
@@ -942,7 +1033,7 @@ def test_call_pipeline_job_create_schedule(
pipeline_job_schedule = job.create_schedule(
display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
- cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
service_account=_TEST_SERVICE_ACCOUNT,
@@ -961,7 +1052,79 @@ def test_call_pipeline_job_create_schedule(
pipeline_spec = job_spec.get("pipelineSpec") or job_spec
expected_gapic_pipeline_job_schedule = gca_schedule.Schedule(
display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
- cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
+ max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
+ max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+ create_pipeline_job_request={
+ "parent": _TEST_PARENT,
+ "pipeline_job": {
+ "runtime_config": runtime_config,
+ "pipeline_spec": dict_to_struct(pipeline_spec),
+ "service_account": _TEST_SERVICE_ACCOUNT,
+ "network": _TEST_NETWORK,
+ },
+ },
+ )
+
+ mock_schedule_service_create.assert_called_once_with(
+ parent=_TEST_PARENT,
+ schedule=expected_gapic_pipeline_job_schedule,
+ timeout=None,
+ )
+
+ assert pipeline_job_schedule._gca_resource == make_schedule(
+ gca_schedule.Schedule.State.COMPLETED
+ )
+
+ @pytest.mark.parametrize(
+ "job_spec",
+ [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
+ )
+ def test_call_pipeline_job_create_schedule(
+ self,
+ mock_schedule_service_create,
+ mock_schedule_service_get,
+ job_spec,
+ mock_load_yaml_and_json,
+ ):
+ """Creates a PipelineJobSchedule via PipelineJob.create_schedule()."""
+ aiplatform.init(
+ project=_TEST_PROJECT,
+ staging_bucket=_TEST_GCS_BUCKET_NAME,
+ location=_TEST_LOCATION,
+ credentials=_TEST_CREDENTIALS,
+ )
+
+ job = pipeline_jobs.PipelineJob(
+ display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+ template_path=_TEST_TEMPLATE_PATH,
+ parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+ input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
+ enable_caching=True,
+ )
+
+ pipeline_job_schedule = job.create_schedule(
+ display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
+ max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
+ max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+ service_account=_TEST_SERVICE_ACCOUNT,
+ network=_TEST_NETWORK,
+ )
+
+ expected_runtime_config_dict = {
+ "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
+ "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
+ "inputArtifacts": {"vertex_model": {"artifactId": "456"}},
+ }
+ runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
+ json_format.ParseDict(expected_runtime_config_dict, runtime_config)
+
+ job_spec = yaml.safe_load(job_spec)
+ pipeline_spec = job_spec.get("pipelineSpec") or job_spec
+ expected_gapic_pipeline_job_schedule = gca_schedule.Schedule(
+ display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
create_pipeline_job_request={
@@ -1029,7 +1192,7 @@ def test_done_method_schedule_service(
)
pipeline_job_schedule.create(
- cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
)
@@ -1079,7 +1242,7 @@ def test_pause_resume_schedule_service(
)
pipeline_job_schedule.create(
- cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
)
@@ -1127,7 +1290,7 @@ def test_list_schedules(self, mock_schedule_service_list, mock_load_yaml_and_jso
)
pipeline_job_schedule.create(
- cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
service_account=_TEST_SERVICE_ACCOUNT,
@@ -1141,6 +1304,111 @@ def test_list_schedules(self, mock_schedule_service_list, mock_load_yaml_and_jso
request={"parent": _TEST_PARENT}
)
+ @pytest.mark.usefixtures(
+ "mock_schedule_service_create",
+ "mock_schedule_service_get",
+ "mock_schedule_bucket_exists",
+ )
+ @pytest.mark.parametrize(
+ "job_spec",
+ [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
+ )
+ def test_preview_list_schedule_jobs(
+ self,
+ mock_pipeline_service_list,
+ mock_load_yaml_and_json,
+ ):
+ aiplatform.init(
+ project=_TEST_PROJECT,
+ staging_bucket=_TEST_GCS_BUCKET_NAME,
+ location=_TEST_LOCATION,
+ credentials=_TEST_CREDENTIALS,
+ )
+
+ job = pipeline_jobs.PipelineJob(
+ display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+ template_path=_TEST_TEMPLATE_PATH,
+ parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+ input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
+ enable_caching=True,
+ )
+
+ pipeline_job_schedule = preview_pipeline_job_schedules.PipelineJobSchedule(
+ pipeline_job=job,
+ display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
+ )
+
+ pipeline_job_schedule.create(
+ cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
+ max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
+ max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+ service_account=_TEST_SERVICE_ACCOUNT,
+ network=_TEST_NETWORK,
+ create_request_timeout=None,
+ )
+
+ pipeline_job_schedule.list_jobs()
+
+ mock_pipeline_service_list.assert_called_once_with(
+ request={
+ "parent": _TEST_PARENT,
+ "filter": f"schedule_name={_TEST_PIPELINE_JOB_SCHEDULE_NAME}",
+ },
+ )
+
+ @pytest.mark.usefixtures(
+ "mock_schedule_service_create",
+ "mock_schedule_service_get",
+ "mock_schedule_bucket_exists",
+ )
+ @pytest.mark.parametrize(
+ "job_spec",
+ [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
+ )
+ def test_preview_list_schedule_jobs_with_read_mask(
+ self,
+ mock_pipeline_service_list,
+ mock_load_yaml_and_json,
+ ):
+ aiplatform.init(
+ project=_TEST_PROJECT,
+ staging_bucket=_TEST_GCS_BUCKET_NAME,
+ location=_TEST_LOCATION,
+ credentials=_TEST_CREDENTIALS,
+ )
+
+ job = pipeline_jobs.PipelineJob(
+ display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+ template_path=_TEST_TEMPLATE_PATH,
+ parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+ input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
+ enable_caching=True,
+ )
+
+ pipeline_job_schedule = preview_pipeline_job_schedules.PipelineJobSchedule(
+ pipeline_job=job,
+ display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
+ )
+
+ pipeline_job_schedule.create(
+ cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
+ max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
+ max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+ service_account=_TEST_SERVICE_ACCOUNT,
+ network=_TEST_NETWORK,
+ create_request_timeout=None,
+ )
+
+ pipeline_job_schedule.list_jobs(enable_simple_view=True)
+
+ mock_pipeline_service_list.assert_called_once_with(
+ request={
+ "parent": _TEST_PARENT,
+ "read_mask": _TEST_PIPELINE_JOB_LIST_READ_MASK,
+ "filter": f"schedule_name={_TEST_PIPELINE_JOB_SCHEDULE_NAME}",
+ },
+ )
+
@pytest.mark.usefixtures(
"mock_schedule_service_create",
"mock_schedule_service_get",
@@ -1176,7 +1444,59 @@ def test_list_schedule_jobs(
)
pipeline_job_schedule.create(
- cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
+ max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
+ max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+ service_account=_TEST_SERVICE_ACCOUNT,
+ network=_TEST_NETWORK,
+ create_request_timeout=None,
+ )
+
+ pipeline_job_schedule.list_jobs(enable_simple_view=False)
+
+ mock_pipeline_service_list.assert_called_once_with(
+ request={
+ "parent": _TEST_PARENT,
+ "filter": f"schedule_name={_TEST_PIPELINE_JOB_SCHEDULE_NAME}",
+ },
+ )
+
+ @pytest.mark.usefixtures(
+ "mock_schedule_service_create",
+ "mock_schedule_service_get",
+ "mock_schedule_bucket_exists",
+ )
+ @pytest.mark.parametrize(
+ "job_spec",
+ [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
+ )
+ def test_list_schedule_jobs_with_read_mask(
+ self,
+ mock_pipeline_service_list,
+ mock_load_yaml_and_json,
+ ):
+ aiplatform.init(
+ project=_TEST_PROJECT,
+ staging_bucket=_TEST_GCS_BUCKET_NAME,
+ location=_TEST_LOCATION,
+ credentials=_TEST_CREDENTIALS,
+ )
+
+ job = pipeline_jobs.PipelineJob(
+ display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+ template_path=_TEST_TEMPLATE_PATH,
+ parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+ input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
+ enable_caching=True,
+ )
+
+ pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
+ pipeline_job=job,
+ display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
+ )
+
+ pipeline_job_schedule.create(
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
service_account=_TEST_SERVICE_ACCOUNT,
@@ -1189,6 +1509,7 @@ def test_list_schedule_jobs(
mock_pipeline_service_list.assert_called_once_with(
request={
"parent": _TEST_PARENT,
+ "read_mask": _TEST_PIPELINE_JOB_LIST_READ_MASK,
"filter": f"schedule_name={_TEST_PIPELINE_JOB_SCHEDULE_NAME}",
},
)
@@ -1292,7 +1613,7 @@ def test_call_schedule_service_update(
):
"""Updates a PipelineJobSchedule.
- Updates cron_expression and max_run_count.
+ Updates cron and max_run_count.
"""
aiplatform.init(
project=_TEST_PROJECT,
@@ -1315,7 +1636,7 @@ def test_call_schedule_service_update(
)
pipeline_job_schedule.create(
- cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
service_account=_TEST_SERVICE_ACCOUNT,
@@ -1324,7 +1645,7 @@ def test_call_schedule_service_update(
)
pipeline_job_schedule.update(
- cron_expression=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON,
max_run_count=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
)
@@ -1332,7 +1653,7 @@ def test_call_schedule_service_update(
name=_TEST_PIPELINE_JOB_SCHEDULE_NAME,
state=gca_schedule.Schedule.State.COMPLETED,
create_time=_TEST_PIPELINE_CREATE_TIME,
- cron=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
create_pipeline_job_request=_TEST_CREATE_PIPELINE_JOB_REQUEST,
@@ -1380,7 +1701,7 @@ def test_call_schedule_service_update_before_create(
with pytest.raises(RuntimeError) as e:
pipeline_job_schedule.update(
- cron_expression=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON,
max_run_count=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
)
@@ -1430,7 +1751,7 @@ def test_get_max_run_count_before_create(
assert e.match(regexp=r"PipelineJobSchedule resource has not been created.")
pipeline_job_schedule.create(
- cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
service_account=_TEST_SERVICE_ACCOUNT,
@@ -1444,7 +1765,7 @@ def test_get_max_run_count_before_create(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
- def test_get_cron_expression_before_create(
+ def test_get_cron_before_create(
self,
mock_schedule_service_create,
mock_schedule_service_get,
@@ -1476,13 +1797,65 @@ def test_get_cron_expression_before_create(
display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
)
+ with pytest.raises(RuntimeError) as e:
+ pipeline_job_schedule.cron
+
+ assert e.match(regexp=r"PipelineJobSchedule resource has not been created.")
+
+ pipeline_job_schedule.create(
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
+ max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
+ max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+ service_account=_TEST_SERVICE_ACCOUNT,
+ network=_TEST_NETWORK,
+ create_request_timeout=None,
+ )
+
+ pipeline_job_schedule.cron
+
+ @pytest.mark.parametrize(
+ "job_spec",
+ [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
+ )
+ def test_get_cron_expression_before_create(
+ self,
+ mock_schedule_service_create,
+ mock_schedule_service_get,
+ mock_schedule_bucket_exists,
+ job_spec,
+ mock_load_yaml_and_json,
+ ):
+ """Gets the PipelineJobSchedule cron expression before creating.
+
+ Raises error because PipelineJobSchedule should be created first.
+ """
+ aiplatform.init(
+ project=_TEST_PROJECT,
+ staging_bucket=_TEST_GCS_BUCKET_NAME,
+ location=_TEST_LOCATION,
+ credentials=_TEST_CREDENTIALS,
+ )
+
+ job = pipeline_jobs.PipelineJob(
+ display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+ template_path=_TEST_TEMPLATE_PATH,
+ parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+ input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
+ enable_caching=True,
+ )
+
+ pipeline_job_schedule = preview_pipeline_job_schedules.PipelineJobSchedule(
+ pipeline_job=job,
+ display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
+ )
+
with pytest.raises(RuntimeError) as e:
pipeline_job_schedule.cron_expression
assert e.match(regexp=r"PipelineJobSchedule resource has not been created.")
pipeline_job_schedule.create(
- cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
service_account=_TEST_SERVICE_ACCOUNT,
@@ -1534,7 +1907,7 @@ def test_get_max_concurrent_run_count_before_create(
assert e.match(regexp=r"PipelineJobSchedule resource has not been created.")
pipeline_job_schedule.create(
- cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
service_account=_TEST_SERVICE_ACCOUNT,
@@ -1586,7 +1959,7 @@ def test_get_allow_queueing_before_create(
assert e.match(regexp=r"PipelineJobSchedule resource has not been created.")
pipeline_job_schedule.create(
- cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+ cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
service_account=_TEST_SERVICE_ACCOUNT,
diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py
index 80a5169e8c..a35e644b46 100644
--- a/tests/unit/aiplatform/test_training_jobs.py
+++ b/tests/unit/aiplatform/test_training_jobs.py
@@ -233,6 +233,7 @@
test_constants.TrainingJobConstants._TEST_RESTART_JOB_ON_WORKER_RESTART
)
+_TEST_DISABLE_RETRIES = test_constants.TrainingJobConstants._TEST_DISABLE_RETRIES
_TEST_ENABLE_WEB_ACCESS = test_constants.TrainingJobConstants._TEST_ENABLE_WEB_ACCESS
_TEST_ENABLE_DASHBOARD_ACCESS = True
_TEST_WEB_ACCESS_URIS = test_constants.TrainingJobConstants._TEST_WEB_ACCESS_URIS
@@ -278,6 +279,7 @@ def _get_custom_job_proto_with_scheduling(state=None, name=None, version="v1"):
custom_job_proto.job_spec.scheduling.restart_job_on_worker_restart = (
_TEST_RESTART_JOB_ON_WORKER_RESTART
)
+ custom_job_proto.job_spec.scheduling.disable_retries = _TEST_DISABLE_RETRIES
return custom_job_proto
@@ -730,6 +732,7 @@ def make_training_pipeline_with_scheduling(state):
training_task_inputs={
"timeout": f"{_TEST_TIMEOUT}s",
"restart_job_on_worker_restart": _TEST_RESTART_JOB_ON_WORKER_RESTART,
+ "disable_retries": _TEST_DISABLE_RETRIES,
},
)
if state == gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING:
@@ -2251,6 +2254,7 @@ def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog):
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
sync=sync,
create_request_timeout=None,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
if not sync:
@@ -2269,6 +2273,10 @@ def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog):
job._gca_resource.training_task_inputs["restart_job_on_worker_restart"]
== _TEST_RESTART_JOB_ON_WORKER_RESTART
)
+ assert (
+ job._gca_resource.training_task_inputs["disable_retries"]
+ == _TEST_DISABLE_RETRIES
+ )
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
@@ -4250,6 +4258,7 @@ def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog):
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
sync=sync,
create_request_timeout=None,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
if not sync:
@@ -4268,6 +4277,10 @@ def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog):
job._gca_resource.training_task_inputs["restart_job_on_worker_restart"]
== _TEST_RESTART_JOB_ON_WORKER_RESTART
)
+ assert (
+ job._gca_resource.training_task_inputs["disable_retries"]
+ == _TEST_DISABLE_RETRIES
+ )
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
@@ -6525,6 +6538,7 @@ def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog):
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
sync=sync,
create_request_timeout=None,
+ disable_retries=_TEST_DISABLE_RETRIES,
)
if not sync:
@@ -6543,6 +6557,10 @@ def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog):
job._gca_resource.training_task_inputs["restart_job_on_worker_restart"]
== _TEST_RESTART_JOB_ON_WORKER_RESTART
)
+ assert (
+ job._gca_resource.training_task_inputs["disable_retries"]
+ == _TEST_DISABLE_RETRIES
+ )
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
diff --git a/tests/unit/aiplatform/test_vision_models.py b/tests/unit/aiplatform/test_vision_models.py
index 4f3c8e74e2..5f01a5bcdd 100644
--- a/tests/unit/aiplatform/test_vision_models.py
+++ b/tests/unit/aiplatform/test_vision_models.py
@@ -17,9 +17,13 @@
# pylint: disable=protected-access,bad-continuation
+import base64
import importlib
+import io
import os
import tempfile
+from typing import Any, Dict
+import unittest
from unittest import mock
from google.cloud import aiplatform
@@ -35,6 +39,7 @@
from google.cloud.aiplatform.compat.types import (
publisher_model as gca_publisher_model,
)
+from vertexai import vision_models as ga_vision_models
from vertexai.preview import vision_models
from PIL import Image as PIL_Image
@@ -70,6 +75,50 @@
}
+_IMAGE_GENERATION_PUBLISHER_MODEL_DICT = {
+ "name": "publishers/google/models/imagegeneration",
+ "version_id": "002",
+ "open_source_category": "PROPRIETARY",
+ "launch_stage": gca_publisher_model.PublisherModel.LaunchStage.GA,
+ "publisher_model_template": "projects/{project}/locations/{location}/publishers/google/models/imagegeneration@002",
+ "predict_schemata": {
+ "instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/vision_generative_model_1.0.0.yaml",
+ "parameters_schema_uri": "gs://google-cloud-aiplatfrom/schema/predict/params/vision_generative_model_1.0.0.yaml",
+ "prediction_schema_uri": "gs://google-cloud-aiplatform/schema/predict/prediction/vision_generative_model_1.0.0.yaml",
+ },
+}
+
+
+def make_image_base64(width: int, height: int) -> str:
+ image: PIL_Image.Image = PIL_Image.new(mode="RGB", size=(width, height))
+ image_bytes = io.BytesIO()
+ image.save(image_bytes, format="PNG")
+ image_b64 = base64.b64encode(image_bytes.getvalue()).decode("utf-8")
+ return image_b64
+
+
+def make_image_generation_response(
+ width: int, height: int, count: int = 1
+) -> Dict[str, Any]:
+ predictions = []
+ for _ in range(count):
+ predictions.append(
+ {
+ "bytesBase64Encoded": make_image_base64(width, height),
+ "mimeType": "image/png",
+ }
+ )
+ return {"predictions": predictions}
+
+
+def make_image_upscale_response(upscale_size: int) -> Dict[str, Any]:
+ predictions = {
+ "bytesBase64Encoded": make_image_base64(upscale_size, upscale_size),
+ "mimeType": "image/png",
+ }
+ return {"predictions": [predictions]}
+
+
def generate_image_from_file(
width: int = 100, height: int = 100
) -> vision_models.Image:
@@ -80,6 +129,297 @@ def generate_image_from_file(
return vision_models.Image.load_from_file(image_path)
+@pytest.mark.usefixtures("google_auth_mock")
+class TestImageGenerationModels:
+ """Unit tests for the image generation models."""
+
+ def setup_method(self):
+ importlib.reload(initializer)
+ importlib.reload(aiplatform)
+
+ def teardown_method(self):
+ initializer.global_pool.shutdown(wait=True)
+
+ def _get_image_generation_model(self) -> vision_models.ImageGenerationModel:
+ """Gets the image generation model."""
+ aiplatform.init(
+ project=_TEST_PROJECT,
+ location=_TEST_LOCATION,
+ )
+ with mock.patch.object(
+ target=model_garden_service_client.ModelGardenServiceClient,
+ attribute="get_publisher_model",
+ return_value=gca_publisher_model.PublisherModel(
+ _IMAGE_GENERATION_PUBLISHER_MODEL_DICT
+ ),
+ ) as mock_get_publisher_model:
+ model = vision_models.ImageGenerationModel.from_pretrained(
+ "imagegeneration@002"
+ )
+
+ mock_get_publisher_model.assert_called_once_with(
+ name="publishers/google/models/imagegeneration@002",
+ retry=base._DEFAULT_RETRY,
+ )
+
+ return model
+
+ def test_from_pretrained(self):
+ model = self._get_image_generation_model()
+ assert (
+ model._endpoint_name
+ == f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/imagegeneration@002"
+ )
+
+ def test_generate_images(self):
+ """Tests the image generation model."""
+ model = self._get_image_generation_model()
+
+ width = 1024
+ # TODO(b/295946075) The service stopped supporting image sizes.
+ # height = 768
+ height = 1024
+ number_of_images = 4
+ seed = 1
+ guidance_scale = 15
+
+ image_generation_response = make_image_generation_response(
+ width=width, height=height, count=number_of_images
+ )
+ gca_predict_response = gca_prediction_service.PredictResponse()
+ gca_predict_response.predictions.extend(
+ image_generation_response["predictions"]
+ )
+
+ with mock.patch.object(
+ target=prediction_service_client.PredictionServiceClient,
+ attribute="predict",
+ return_value=gca_predict_response,
+ ) as mock_predict:
+ prompt1 = "Astronaut riding a horse"
+ negative_prompt1 = "bad quality"
+ image_response = model.generate_images(
+ prompt=prompt1,
+ # Optional:
+ negative_prompt=negative_prompt1,
+ number_of_images=number_of_images,
+ # TODO(b/295946075) The service stopped supporting image sizes.
+ # width=width,
+ # height=height,
+ seed=seed,
+ guidance_scale=guidance_scale,
+ )
+ predict_kwargs = mock_predict.call_args[1]
+ actual_parameters = predict_kwargs["parameters"]
+ actual_instance = predict_kwargs["instances"][0]
+ assert actual_instance["prompt"] == prompt1
+ assert actual_instance["negativePrompt"] == negative_prompt1
+ # TODO(b/295946075) The service stopped supporting image sizes.
+ # assert actual_parameters["sampleImageSize"] == str(max(width, height))
+ # assert actual_parameters["aspectRatio"] == f"{width}:{height}"
+ assert actual_parameters["seed"] == seed
+ assert actual_parameters["guidanceScale"] == guidance_scale
+
+ assert len(image_response.images) == number_of_images
+ for idx, image in enumerate(image_response):
+ assert image._pil_image.size == (width, height)
+ assert image.generation_parameters
+ assert image.generation_parameters["prompt"] == prompt1
+ assert image.generation_parameters["negative_prompt"] == negative_prompt1
+ # TODO(b/295946075) The service stopped supporting image sizes.
+ # assert image.generation_parameters["width"] == width
+ # assert image.generation_parameters["height"] == height
+ assert image.generation_parameters["seed"] == seed
+ assert image.generation_parameters["guidance_scale"] == guidance_scale
+ assert image.generation_parameters["index_of_image_in_batch"] == idx
+ image.show()
+
+ # Test saving and loading images
+ with tempfile.TemporaryDirectory() as temp_dir:
+ image_path = os.path.join(temp_dir, "image.png")
+ image_response[0].save(location=image_path)
+ image1 = vision_models.GeneratedImage.load_from_file(image_path)
+ # assert image1._pil_image.size == (width, height)
+ assert image1.generation_parameters
+ assert image1.generation_parameters["prompt"] == prompt1
+
+ # Preparing mask
+ mask_path = os.path.join(temp_dir, "mask.png")
+ mask_pil_image = PIL_Image.new(mode="RGB", size=image1._pil_image.size)
+ mask_pil_image.save(mask_path, format="PNG")
+ mask_image = vision_models.Image.load_from_file(mask_path)
+
+ # Test generating image from base image
+ with mock.patch.object(
+ target=prediction_service_client.PredictionServiceClient,
+ attribute="predict",
+ return_value=gca_predict_response,
+ ) as mock_predict:
+ prompt2 = "Ancient book style"
+ image_response2 = model.edit_image(
+ prompt=prompt2,
+ # Optional:
+ number_of_images=number_of_images,
+ seed=seed,
+ guidance_scale=guidance_scale,
+ base_image=image1,
+ mask=mask_image,
+ )
+ predict_kwargs = mock_predict.call_args[1]
+ actual_instance = predict_kwargs["instances"][0]
+ assert actual_instance["prompt"] == prompt2
+ assert actual_instance["image"]["bytesBase64Encoded"]
+ assert actual_instance["mask"]["image"]["bytesBase64Encoded"]
+
+ assert len(image_response2.images) == number_of_images
+ for image in image_response2:
+ assert image._pil_image.size == (width, height)
+ assert image.generation_parameters
+ assert image.generation_parameters["prompt"] == prompt2
+ assert image.generation_parameters["base_image_hash"]
+ assert image.generation_parameters["mask_hash"]
+
+ @unittest.skip(reason="b/295946075 The service stopped supporting image sizes.")
+ def test_generate_images_requests_square_images_by_default(self):
+ """Tests that the model class generates square image by default."""
+ model = self._get_image_generation_model()
+
+ image_size = 1024
+
+ # No height specified
+ with mock.patch.object(
+ target=prediction_service_client.PredictionServiceClient,
+ attribute="predict",
+ ) as mock_predict:
+ model.generate_images(
+ prompt="test",
+ width=image_size,
+ )
+ predict_kwargs = mock_predict.call_args[1]
+ actual_parameters = predict_kwargs["parameters"]
+ assert actual_parameters["sampleImageSize"] == str(image_size)
+ assert "aspectRatio" not in actual_parameters
+
+ # No width specified
+ with mock.patch.object(
+ target=prediction_service_client.PredictionServiceClient,
+ attribute="predict",
+ ) as mock_predict:
+ model.generate_images(
+ prompt="test",
+ height=image_size,
+ )
+ predict_kwargs = mock_predict.call_args[1]
+ actual_parameters = predict_kwargs["parameters"]
+ assert actual_parameters["sampleImageSize"] == str(image_size)
+ assert "aspectRatio" not in actual_parameters
+
+ # No width or height specified
+ with mock.patch.object(
+ target=prediction_service_client.PredictionServiceClient,
+ attribute="predict",
+ ) as mock_predict:
+ model.generate_images(prompt="test")
+ predict_kwargs = mock_predict.call_args[1]
+ actual_parameters = predict_kwargs["parameters"]
+ assert "sampleImageSize" not in actual_parameters
+
+ def test_upscale_image_on_generated_image(self):
+ """Tests image upscaling on generated images."""
+ model = self._get_image_generation_model()
+
+ image_generation_response = make_image_generation_response(
+ count=1, height=1024, width=1024
+ )
+ gca_generation_response = gca_prediction_service.PredictResponse()
+ gca_generation_response.predictions.extend(
+ image_generation_response["predictions"]
+ )
+
+ image_upscale_response = make_image_upscale_response(upscale_size=2048)
+ gca_upscale_response = gca_prediction_service.PredictResponse()
+ gca_upscale_response.predictions.extend(image_upscale_response["predictions"])
+
+ with mock.patch.object(
+ target=prediction_service_client.PredictionServiceClient,
+ attribute="predict",
+ return_value=gca_generation_response,
+ ):
+ prompt = "Ancient book style"
+ image_generation_response = model.generate_images(
+ prompt=prompt,
+ )
+
+ with mock.patch.object(
+ target=prediction_service_client.PredictionServiceClient,
+ attribute="predict",
+ return_value=gca_upscale_response,
+ ) as mock_upscale:
+ upscaled_image = model.upscale_image(image=image_generation_response[0])
+
+ predict_kwargs = mock_upscale.call_args[1]
+ actual_instance = predict_kwargs["instances"][0]
+ assert actual_instance["image"]["bytesBase64Encoded"]
+
+ image_upscale_parameters = predict_kwargs["parameters"]
+ assert image_upscale_parameters["sampleImageSize"] == str(
+ upscaled_image._size[0]
+ )
+ assert image_upscale_parameters["mode"] == "upscale"
+
+ assert upscaled_image._image_bytes
+ assert upscaled_image.generation_parameters["prompt"] == prompt
+
+ def test_upscale_image_on_provided_image(self):
+ """Tests image upscaling on generated images."""
+ model = self._get_image_generation_model()
+
+ image_generation_response = make_image_generation_response(
+ count=1, height=1024, width=1024
+ )
+ gca_generation_response = gca_prediction_service.PredictResponse()
+ gca_generation_response.predictions.extend(
+ image_generation_response["predictions"]
+ )
+
+ image_upscale_response = make_image_upscale_response(upscale_size=4096)
+ gca_upscale_response = gca_prediction_service.PredictResponse()
+ gca_upscale_response.predictions.extend(image_upscale_response["predictions"])
+
+ with mock.patch.object(
+ target=prediction_service_client.PredictionServiceClient,
+ attribute="predict",
+ return_value=gca_upscale_response,
+ ) as mock_upscale:
+ test_image = generate_image_from_file(height=1024, width=1024)
+
+ upscaled_image = model.upscale_image(image=test_image, new_size=4096)
+
+ predict_kwargs = mock_upscale.call_args[1]
+ actual_instance = predict_kwargs["instances"][0]
+ assert actual_instance["image"]["bytesBase64Encoded"]
+
+ image_upscale_parameters = predict_kwargs["parameters"]
+ assert (
+ image_upscale_parameters["sampleImageSize"]
+ == str(upscaled_image._size[0])
+ == str(upscaled_image.generation_parameters["upscaled_image_size"])
+ )
+ assert image_upscale_parameters["mode"] == "upscale"
+
+ assert upscaled_image._image_bytes
+ assert isinstance(upscaled_image, vision_models.GeneratedImage)
+
+ def test_upscale_image_raises_if_not_1024x1024(self):
+ """Tests image upscaling on generated images."""
+ model = self._get_image_generation_model()
+
+ test_image = generate_image_from_file(height=100, width=100)
+
+ with pytest.raises(ValueError):
+ model.upscale_image(image=test_image)
+
+
@pytest.mark.usefixtures("google_auth_mock")
class ImageCaptioningModelTests:
"""Unit tests for the image captioning models."""
@@ -102,7 +442,9 @@ def test_get_captions(self):
attribute="get_publisher_model",
return_value=gca_publisher_model(_IMAGE_TEXT_PUBLISHER_MODEL_DICT),
):
- model = vision_models.ImageCaptioningModel.from_pretrained("imagetext@001")
+ model = ga_vision_models.ImageCaptioningModel.from_pretrained(
+ "imagetext@001"
+ )
image_captions = [
"Caption 1",
@@ -150,7 +492,7 @@ def test_get_captions(self):
_IMAGE_TEXT_PUBLISHER_MODEL_DICT
),
) as mock_get_publisher_model:
- model = vision_models.ImageQnAModel.from_pretrained("imagetext@001")
+ model = ga_vision_models.ImageQnAModel.from_pretrained("imagetext@001")
mock_get_publisher_model.assert_called_once_with(
name="publishers/google/models/imagetext@001",
@@ -277,7 +619,7 @@ def test_image_embedding_model_with_only_text(self):
_IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
),
):
- model = vision_models.MultiModalEmbeddingModel.from_pretrained(
+ model = ga_vision_models.MultiModalEmbeddingModel.from_pretrained(
"multimodalembedding@001"
)
diff --git a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py
index f547b1d7fe..38dc7a1917 100644
--- a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py
+++ b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py
@@ -2224,6 +2224,7 @@ def test_create_pipeline_job(request_type, transport: str = "grpc"):
network="network_value",
reserved_ip_ranges=["reserved_ip_ranges_value"],
template_uri="template_uri_value",
+ schedule_name="schedule_name_value",
)
response = client.create_pipeline_job(request)
@@ -2241,6 +2242,7 @@ def test_create_pipeline_job(request_type, transport: str = "grpc"):
assert response.network == "network_value"
assert response.reserved_ip_ranges == ["reserved_ip_ranges_value"]
assert response.template_uri == "template_uri_value"
+ assert response.schedule_name == "schedule_name_value"
def test_create_pipeline_job_empty_call():
@@ -2289,6 +2291,7 @@ async def test_create_pipeline_job_async(
network="network_value",
reserved_ip_ranges=["reserved_ip_ranges_value"],
template_uri="template_uri_value",
+ schedule_name="schedule_name_value",
)
)
response = await client.create_pipeline_job(request)
@@ -2307,6 +2310,7 @@ async def test_create_pipeline_job_async(
assert response.network == "network_value"
assert response.reserved_ip_ranges == ["reserved_ip_ranges_value"]
assert response.template_uri == "template_uri_value"
+ assert response.schedule_name == "schedule_name_value"
@pytest.mark.asyncio
@@ -2513,6 +2517,7 @@ def test_get_pipeline_job(request_type, transport: str = "grpc"):
network="network_value",
reserved_ip_ranges=["reserved_ip_ranges_value"],
template_uri="template_uri_value",
+ schedule_name="schedule_name_value",
)
response = client.get_pipeline_job(request)
@@ -2530,6 +2535,7 @@ def test_get_pipeline_job(request_type, transport: str = "grpc"):
assert response.network == "network_value"
assert response.reserved_ip_ranges == ["reserved_ip_ranges_value"]
assert response.template_uri == "template_uri_value"
+ assert response.schedule_name == "schedule_name_value"
def test_get_pipeline_job_empty_call():
@@ -2573,6 +2579,7 @@ async def test_get_pipeline_job_async(
network="network_value",
reserved_ip_ranges=["reserved_ip_ranges_value"],
template_uri="template_uri_value",
+ schedule_name="schedule_name_value",
)
)
response = await client.get_pipeline_job(request)
@@ -2591,6 +2598,7 @@ async def test_get_pipeline_job_async(
assert response.network == "network_value"
assert response.reserved_ip_ranges == ["reserved_ip_ranges_value"]
assert response.template_uri == "template_uri_value"
+ assert response.schedule_name == "schedule_name_value"
@pytest.mark.asyncio
diff --git a/tests/unit/gapic/aiplatform_v1/test_tensorboard_service.py b/tests/unit/gapic/aiplatform_v1/test_tensorboard_service.py
index 0ae10d5741..5eef17c4ac 100644
--- a/tests/unit/gapic/aiplatform_v1/test_tensorboard_service.py
+++ b/tests/unit/gapic/aiplatform_v1/test_tensorboard_service.py
@@ -2415,6 +2415,253 @@ async def test_read_tensorboard_usage_flattened_error_async():
)
+@pytest.mark.parametrize(
+ "request_type",
+ [
+ tensorboard_service.ReadTensorboardSizeRequest,
+ dict,
+ ],
+)
+def test_read_tensorboard_size(request_type, transport: str = "grpc"):
+ client = TensorboardServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.read_tensorboard_size), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = tensorboard_service.ReadTensorboardSizeResponse(
+ storage_size_byte=1826,
+ )
+ response = client.read_tensorboard_size(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == tensorboard_service.ReadTensorboardSizeRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, tensorboard_service.ReadTensorboardSizeResponse)
+ assert response.storage_size_byte == 1826
+
+
+def test_read_tensorboard_size_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = TensorboardServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.read_tensorboard_size), "__call__"
+ ) as call:
+ client.read_tensorboard_size()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == tensorboard_service.ReadTensorboardSizeRequest()
+
+
+@pytest.mark.asyncio
+async def test_read_tensorboard_size_async(
+ transport: str = "grpc_asyncio",
+ request_type=tensorboard_service.ReadTensorboardSizeRequest,
+):
+ client = TensorboardServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.read_tensorboard_size), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ tensorboard_service.ReadTensorboardSizeResponse(
+ storage_size_byte=1826,
+ )
+ )
+ response = await client.read_tensorboard_size(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == tensorboard_service.ReadTensorboardSizeRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, tensorboard_service.ReadTensorboardSizeResponse)
+ assert response.storage_size_byte == 1826
+
+
+@pytest.mark.asyncio
+async def test_read_tensorboard_size_async_from_dict():
+ await test_read_tensorboard_size_async(request_type=dict)
+
+
+def test_read_tensorboard_size_field_headers():
+ client = TensorboardServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = tensorboard_service.ReadTensorboardSizeRequest()
+
+ request.tensorboard = "tensorboard_value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.read_tensorboard_size), "__call__"
+ ) as call:
+ call.return_value = tensorboard_service.ReadTensorboardSizeResponse()
+ client.read_tensorboard_size(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert (
+ "x-goog-request-params",
+ "tensorboard=tensorboard_value",
+ ) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_read_tensorboard_size_field_headers_async():
+ client = TensorboardServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = tensorboard_service.ReadTensorboardSizeRequest()
+
+ request.tensorboard = "tensorboard_value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.read_tensorboard_size), "__call__"
+ ) as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ tensorboard_service.ReadTensorboardSizeResponse()
+ )
+ await client.read_tensorboard_size(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert (
+ "x-goog-request-params",
+ "tensorboard=tensorboard_value",
+ ) in kw["metadata"]
+
+
+def test_read_tensorboard_size_flattened():
+ client = TensorboardServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.read_tensorboard_size), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = tensorboard_service.ReadTensorboardSizeResponse()
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.read_tensorboard_size(
+ tensorboard="tensorboard_value",
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ arg = args[0].tensorboard
+ mock_val = "tensorboard_value"
+ assert arg == mock_val
+
+
+def test_read_tensorboard_size_flattened_error():
+ client = TensorboardServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.read_tensorboard_size(
+ tensorboard_service.ReadTensorboardSizeRequest(),
+ tensorboard="tensorboard_value",
+ )
+
+
+@pytest.mark.asyncio
+async def test_read_tensorboard_size_flattened_async():
+ client = TensorboardServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.read_tensorboard_size), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = tensorboard_service.ReadTensorboardSizeResponse()
+
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ tensorboard_service.ReadTensorboardSizeResponse()
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.read_tensorboard_size(
+ tensorboard="tensorboard_value",
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ arg = args[0].tensorboard
+ mock_val = "tensorboard_value"
+ assert arg == mock_val
+
+
+@pytest.mark.asyncio
+async def test_read_tensorboard_size_flattened_error_async():
+ client = TensorboardServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.read_tensorboard_size(
+ tensorboard_service.ReadTensorboardSizeRequest(),
+ tensorboard="tensorboard_value",
+ )
+
+
@pytest.mark.parametrize(
"request_type",
[
@@ -9450,6 +9697,7 @@ def test_tensorboard_service_base_transport():
"list_tensorboards",
"delete_tensorboard",
"read_tensorboard_usage",
+ "read_tensorboard_size",
"create_tensorboard_experiment",
"get_tensorboard_experiment",
"update_tensorboard_experiment",
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py
index 62cf6a5c91..bcc0847d18 100644
--- a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py
+++ b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py
@@ -2008,22 +2008,19 @@ def test_parse_annotated_dataset_path():
def test_dataset_path():
project = "cuttlefish"
- location = "mussel"
- dataset = "winkle"
- expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(
+ dataset = "mussel"
+ expected = "projects/{project}/datasets/{dataset}".format(
project=project,
- location=location,
dataset=dataset,
)
- actual = MigrationServiceClient.dataset_path(project, location, dataset)
+ actual = MigrationServiceClient.dataset_path(project, dataset)
assert expected == actual
def test_parse_dataset_path():
expected = {
- "project": "nautilus",
- "location": "scallop",
- "dataset": "abalone",
+ "project": "winkle",
+ "dataset": "nautilus",
}
path = MigrationServiceClient.dataset_path(**expected)
@@ -2033,9 +2030,9 @@ def test_parse_dataset_path():
def test_dataset_path():
- project = "squid"
- location = "clam"
- dataset = "whelk"
+ project = "scallop"
+ location = "abalone"
+ dataset = "squid"
expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(
project=project,
location=location,
@@ -2047,9 +2044,9 @@ def test_dataset_path():
def test_parse_dataset_path():
expected = {
- "project": "octopus",
- "location": "oyster",
- "dataset": "nudibranch",
+ "project": "clam",
+ "location": "whelk",
+ "dataset": "octopus",
}
path = MigrationServiceClient.dataset_path(**expected)
@@ -2059,19 +2056,22 @@ def test_parse_dataset_path():
def test_dataset_path():
- project = "cuttlefish"
- dataset = "mussel"
- expected = "projects/{project}/datasets/{dataset}".format(
+ project = "oyster"
+ location = "nudibranch"
+ dataset = "cuttlefish"
+ expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(
project=project,
+ location=location,
dataset=dataset,
)
- actual = MigrationServiceClient.dataset_path(project, dataset)
+ actual = MigrationServiceClient.dataset_path(project, location, dataset)
assert expected == actual
def test_parse_dataset_path():
expected = {
- "project": "winkle",
+ "project": "mussel",
+ "location": "winkle",
"dataset": "nautilus",
}
path = MigrationServiceClient.dataset_path(**expected)
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py
index ecac2496fe..1640f780c1 100644
--- a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py
+++ b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py
@@ -2228,6 +2228,7 @@ def test_create_pipeline_job(request_type, transport: str = "grpc"):
network="network_value",
reserved_ip_ranges=["reserved_ip_ranges_value"],
template_uri="template_uri_value",
+ schedule_name="schedule_name_value",
)
response = client.create_pipeline_job(request)
@@ -2245,6 +2246,7 @@ def test_create_pipeline_job(request_type, transport: str = "grpc"):
assert response.network == "network_value"
assert response.reserved_ip_ranges == ["reserved_ip_ranges_value"]
assert response.template_uri == "template_uri_value"
+ assert response.schedule_name == "schedule_name_value"
def test_create_pipeline_job_empty_call():
@@ -2293,6 +2295,7 @@ async def test_create_pipeline_job_async(
network="network_value",
reserved_ip_ranges=["reserved_ip_ranges_value"],
template_uri="template_uri_value",
+ schedule_name="schedule_name_value",
)
)
response = await client.create_pipeline_job(request)
@@ -2311,6 +2314,7 @@ async def test_create_pipeline_job_async(
assert response.network == "network_value"
assert response.reserved_ip_ranges == ["reserved_ip_ranges_value"]
assert response.template_uri == "template_uri_value"
+ assert response.schedule_name == "schedule_name_value"
@pytest.mark.asyncio
@@ -2517,6 +2521,7 @@ def test_get_pipeline_job(request_type, transport: str = "grpc"):
network="network_value",
reserved_ip_ranges=["reserved_ip_ranges_value"],
template_uri="template_uri_value",
+ schedule_name="schedule_name_value",
)
response = client.get_pipeline_job(request)
@@ -2534,6 +2539,7 @@ def test_get_pipeline_job(request_type, transport: str = "grpc"):
assert response.network == "network_value"
assert response.reserved_ip_ranges == ["reserved_ip_ranges_value"]
assert response.template_uri == "template_uri_value"
+ assert response.schedule_name == "schedule_name_value"
def test_get_pipeline_job_empty_call():
@@ -2577,6 +2583,7 @@ async def test_get_pipeline_job_async(
network="network_value",
reserved_ip_ranges=["reserved_ip_ranges_value"],
template_uri="template_uri_value",
+ schedule_name="schedule_name_value",
)
)
response = await client.get_pipeline_job(request)
@@ -2595,6 +2602,7 @@ async def test_get_pipeline_job_async(
assert response.network == "network_value"
assert response.reserved_ip_ranges == ["reserved_ip_ranges_value"]
assert response.template_uri == "template_uri_value"
+ assert response.schedule_name == "schedule_name_value"
@pytest.mark.asyncio
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py
index c40139c0e3..adf8d887d8 100644
--- a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py
+++ b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py
@@ -1494,6 +1494,252 @@ async def test_explain_flattened_error_async():
)
+@pytest.mark.parametrize(
+ "request_type",
+ [
+ prediction_service.CountTokensRequest,
+ dict,
+ ],
+)
+def test_count_tokens(request_type, transport: str = "grpc"):
+ client = PredictionServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.count_tokens), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = prediction_service.CountTokensResponse(
+ total_tokens=1303,
+ total_billable_characters=2617,
+ )
+ response = client.count_tokens(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == prediction_service.CountTokensRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, prediction_service.CountTokensResponse)
+ assert response.total_tokens == 1303
+ assert response.total_billable_characters == 2617
+
+
+def test_count_tokens_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = PredictionServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.count_tokens), "__call__") as call:
+ client.count_tokens()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == prediction_service.CountTokensRequest()
+
+
+@pytest.mark.asyncio
+async def test_count_tokens_async(
+ transport: str = "grpc_asyncio", request_type=prediction_service.CountTokensRequest
+):
+ client = PredictionServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.count_tokens), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ prediction_service.CountTokensResponse(
+ total_tokens=1303,
+ total_billable_characters=2617,
+ )
+ )
+ response = await client.count_tokens(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == prediction_service.CountTokensRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, prediction_service.CountTokensResponse)
+ assert response.total_tokens == 1303
+ assert response.total_billable_characters == 2617
+
+
+@pytest.mark.asyncio
+async def test_count_tokens_async_from_dict():
+ await test_count_tokens_async(request_type=dict)
+
+
+def test_count_tokens_field_headers():
+ client = PredictionServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = prediction_service.CountTokensRequest()
+
+ request.endpoint = "endpoint_value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.count_tokens), "__call__") as call:
+ call.return_value = prediction_service.CountTokensResponse()
+ client.count_tokens(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert (
+ "x-goog-request-params",
+ "endpoint=endpoint_value",
+ ) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_count_tokens_field_headers_async():
+ client = PredictionServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = prediction_service.CountTokensRequest()
+
+ request.endpoint = "endpoint_value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.count_tokens), "__call__") as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ prediction_service.CountTokensResponse()
+ )
+ await client.count_tokens(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert (
+ "x-goog-request-params",
+ "endpoint=endpoint_value",
+ ) in kw["metadata"]
+
+
+def test_count_tokens_flattened():
+ client = PredictionServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.count_tokens), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = prediction_service.CountTokensResponse()
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.count_tokens(
+ endpoint="endpoint_value",
+ instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)],
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ arg = args[0].endpoint
+ mock_val = "endpoint_value"
+ assert arg == mock_val
+ arg = args[0].instances
+ mock_val = [struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)]
+ assert arg == mock_val
+
+
+def test_count_tokens_flattened_error():
+ client = PredictionServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.count_tokens(
+ prediction_service.CountTokensRequest(),
+ endpoint="endpoint_value",
+ instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)],
+ )
+
+
+@pytest.mark.asyncio
+async def test_count_tokens_flattened_async():
+ client = PredictionServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.count_tokens), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = prediction_service.CountTokensResponse()
+
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ prediction_service.CountTokensResponse()
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.count_tokens(
+ endpoint="endpoint_value",
+ instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)],
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ arg = args[0].endpoint
+ mock_val = "endpoint_value"
+ assert arg == mock_val
+ arg = args[0].instances
+ mock_val = [struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)]
+ assert arg == mock_val
+
+
+@pytest.mark.asyncio
+async def test_count_tokens_flattened_error_async():
+ client = PredictionServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.count_tokens(
+ prediction_service.CountTokensRequest(),
+ endpoint="endpoint_value",
+ instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)],
+ )
+
+
def test_credentials_transport_error():
# It is an error to provide credentials and a transport instance.
transport = transports.PredictionServiceGrpcTransport(
@@ -1635,6 +1881,7 @@ def test_prediction_service_base_transport():
"raw_predict",
"server_streaming_predict",
"explain",
+ "count_tokens",
"set_iam_policy",
"get_iam_policy",
"test_iam_permissions",
diff --git a/vertexai/language_models/__init__.py b/vertexai/language_models/__init__.py
index 9566691f29..8d16584ecb 100644
--- a/vertexai/language_models/__init__.py
+++ b/vertexai/language_models/__init__.py
@@ -23,6 +23,7 @@
CodeGenerationModel,
InputOutputTextPair,
TextEmbedding,
+ TextEmbeddingInput,
TextEmbeddingModel,
TextGenerationModel,
TextGenerationResponse,
@@ -37,6 +38,7 @@
"CodeGenerationModel",
"InputOutputTextPair",
"TextEmbedding",
+ "TextEmbeddingInput",
"TextEmbeddingModel",
"TextGenerationModel",
"TextGenerationResponse",
diff --git a/vertexai/language_models/_evaluatable_language_models.py b/vertexai/language_models/_evaluatable_language_models.py
new file mode 100644
index 0000000000..eb3423fc93
--- /dev/null
+++ b/vertexai/language_models/_evaluatable_language_models.py
@@ -0,0 +1,755 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Classes for working with language models."""
+
+import dataclasses
+import os
+from typing import Any, Dict, List, Optional, Type, TypeVar, Union
+
+from google.cloud import storage
+
+from google.cloud import aiplatform
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import initializer as aiplatform_initializer
+from google.cloud.aiplatform import utils as aiplatform_utils
+from google.cloud.aiplatform.utils import gcs_utils
+from vertexai._model_garden import _model_garden_models
+
+from google.cloud.aiplatform.compat.services import (
+ model_garden_service_client,
+)
+from google.cloud.aiplatform.compat.types import (
+ pipeline_state as gca_pipeline_state,
+)
+
+try:
+ import pandas
+except ImportError:
+ pandas = None
+
+
+_LOGGER = base.Logger(__name__)
+
+# Model Evaluation constants
+_TEXT_CLASSIFICATION_TASK_NAME = "text-classification"
+_TEXT_GENERATION_TASK_NAME = "text-generation"
+_QA_TASK_NAME = "question-answering"
+_SUMMARIZATION_TASK_NAME = "summarization"
+
+_EVALUATION_TASKS = frozenset(
+ [
+ _TEXT_CLASSIFICATION_TASK_NAME,
+ _TEXT_GENERATION_TASK_NAME,
+ _QA_TASK_NAME,
+ _SUMMARIZATION_TASK_NAME,
+ ]
+)
+
+
+_TEXT_CLASSIFICATION_TEMPLATE_URL = "https://us-kfp.pkg.dev/vertex-evaluation/pipeline-templates/evaluation-llm-classification-pipeline"
+_TEXT_GENERATION_QA_SUMMARIZATION_TEMPLATE_URL = "https://us-kfp.pkg.dev/vertex-evaluation/pipeline-templates/evaluation-llm-text-generation-pipeline"
+
+_EVALUATION_TEMPLATE_VERSION_TAG = "1.0.1"
+
+_EVALUATION_TEMPLATE_URLS = {
+ _TEXT_CLASSIFICATION_TASK_NAME: f"{_TEXT_CLASSIFICATION_TEMPLATE_URL}/{_EVALUATION_TEMPLATE_VERSION_TAG}",
+ _TEXT_GENERATION_TASK_NAME: f"{_TEXT_GENERATION_QA_SUMMARIZATION_TEMPLATE_URL}/{_EVALUATION_TEMPLATE_VERSION_TAG}",
+ _QA_TASK_NAME: f"{_TEXT_GENERATION_QA_SUMMARIZATION_TEMPLATE_URL}/{_EVALUATION_TEMPLATE_VERSION_TAG}",
+ _SUMMARIZATION_TASK_NAME: f"{_TEXT_GENERATION_QA_SUMMARIZATION_TEMPLATE_URL}/{_EVALUATION_TEMPLATE_VERSION_TAG}",
+}
+
+
+_EVALUATION_PIPELINE_COMPONENT_IDENTIFIER = "fpc-llm-evaluation"
+
+# TODO: update this when BP removes the input size limit
+_BATCH_PREDICTION_ROW_LIMIT = 1000
+
+_EVAL_SUPPORTED_BASE_MODELS = ["text-bison@001"]
+
+T = TypeVar("T", bound="_EvaluationMetricBase")
+
+
+def _check_dataset_is_within_size_limit(
+ data: "pandas.DataFrame",
+) -> None:
+
+ if len(data) < _BATCH_PREDICTION_ROW_LIMIT:
+ return
+
+ raise ValueError(
+ f"Your evaluation dataset size exceeds the limit of {_BATCH_PREDICTION_ROW_LIMIT}"
+ )
+
+
+def _get_model_resource_name_and_validate(
+ model_name: str,
+ model_info: _model_garden_models._ModelInfo,
+) -> str:
+ """Returns the resource name string for the model.
+
+ Model Registry resource names will stay the same. For Publisher Models, we need to
+ pass the full resource name (publishers/google/models/text-bison@001) to the evaluation
+ template and ensure the base model supports evaluation.
+
+ Args:
+ model_name (str):
+ Required. The full resource name of the Model Registry model or base publisher model
+ to run evaluation on.
+ model_info (_model_garden_models._ModelInfo):
+ Required. The _ModelInfo object for the instance.
+
+ Returns:
+ The formatted model_name string.
+
+ Raises:
+ ValueError
+ If a base PublisherModel was provided and the model doesn't support evaluation.
+ """
+
+ if "publishers/" not in model_name:
+ # Model Registry resource
+ return model_name
+
+ else:
+ if model_info.tuning_model_id in _EVAL_SUPPORTED_BASE_MODELS:
+ return f"{model_info.publisher_model_resource.name}@{model_info.publisher_model_resource.version_id}"
+
+ raise ValueError(
+ f"The provided model {model_name} does not support evaluation."
+ )
+
+
+def _get_template_url(task_name: str) -> Optional[str]:
+ """Returns the pipeline template to use for the evaluation task.
+
+ Args:
+ task_name (str):
+ Required. The name of the evaluation task to run.
+
+ Returns:
+ The evaluation pipeline template path.
+ """
+
+ return _EVALUATION_TEMPLATE_URLS.get(task_name)
+
+
+@dataclasses.dataclass
+class _EvaluationTaskSpec:
+ """Base class for task-specific model evaluation configuration parameters.
+
+ This class should not be instantiated directly, instead use the subclass corresponding
+ to your evaluation task.
+
+ Args:
+ ground_truth_data (Union[List[str], str, pandas.DataFrame]):
+ Required. The ground truth data to use for this evaluation job. This can be
+ either a Pandas DataFrame, a Cloud Storage URI of your JSONL data file, or a list of multiple
+ JSONL files on Cloud Storage.
+
+ Raises:
+ ValueError:
+ If task_spec.ground_truth_data is formatted incorrectly.
+ If task_spec.ground_truth_data is a Pandas DataFrame and exceeds 1000 rows.
+ If task_spec.ground_truth_data is not a string, list, or Pandas DataFrame.
+ """
+
+ ground_truth_data: Union[List[str], str, "pandas.DataFrame"]
+
+ @property
+ def task_name(self) -> str:
+ pass
+
+ def __post_init__(self):
+
+ if isinstance(self.ground_truth_data, str):
+ self.ground_truth_data = [self.ground_truth_data]
+
+ if isinstance(self.ground_truth_data, list) and not all(
+ item.startswith("gs://") for item in self.ground_truth_data
+ ):
+ raise ValueError("Please provide a valid GCS URI starting with 'gs://'")
+
+ if pandas and isinstance(self.ground_truth_data, pandas.DataFrame):
+
+ _check_dataset_is_within_size_limit(self.ground_truth_data)
+
+
+@dataclasses.dataclass
+class EvaluationTextClassificationSpec(_EvaluationTaskSpec):
+ """Spec for text classification model evaluation tasks.
+
+ Args:
+ target_column_name (str):
+ Required. The label column in the dataset provided in `ground_truth_data`. Required when task_name='text-classification'.
+ class_names (List[str]):
+ Required. A list of all possible label names in your dataset. Required when task_name='text-classification'.
+ """
+
+ target_column_name: str
+ class_names: List[str]
+
+ @property
+ def task_name(self) -> str:
+ return "text-classification"
+
+
+@dataclasses.dataclass
+class EvaluationTextGenerationSpec(_EvaluationTaskSpec):
+ """Spec for text generation model evaluation tasks."""
+
+ @property
+ def task_name(self) -> str:
+ return "text-generation"
+
+
+@dataclasses.dataclass
+class EvaluationQuestionAnsweringSpec(_EvaluationTaskSpec):
+ """Spec for question answering model evaluation tasks."""
+
+ task_name: str = "question-answering"
+
+
+@dataclasses.dataclass
+class EvaluationTextSummarizationSpec(_EvaluationTaskSpec):
+ """Spec for text summarization model evaluation tasks."""
+
+ task_name: str = "summarization"
+
+
+@dataclasses.dataclass
+class _EvaluationMetricBase:
+ """Base class for returned evaulation metrics"""
+
+ @property
+ def input_dataset_paths(self) -> str:
+ """The Google Cloud Storage paths to the dataset used for this evaluation."""
+ pass
+
+ @property
+ def task_name(self) -> str:
+ """The type of evaluation task for the evaluation.."""
+ pass
+
+
+@dataclasses.dataclass
+class EvaluationMetric(_EvaluationMetricBase):
+ """The evaluation metric response.
+
+ Args:
+ bleu (float):
+ Optional. BLEU (Bilingual evauation understudy). Scores based on sacrebleu implementation.
+ rougeLSum (float):
+ Optional. ROUGE-L (Longest Common Subsequence) scoring at summary level.
+ """
+
+ bleu: Optional[float] = None
+ rougeLSum: Optional[float] = None
+
+
+@dataclasses.dataclass
+class EvaluationClassificationMetric(_EvaluationMetricBase):
+ """The evaluation metric response for classification metrics.
+
+ Args:
+ label_name (str):
+ Optional. The name of the label associated with the metrics. This is only
+ returned when `only_summary_metrics=False` is passed to evaluate().
+ auPrc (float):
+ Optional. The area under the precision recall curve.
+ auRoc (float):
+ Optional. The area under the receiver operating characteristic curve.
+ logLoss (float):
+ Optional. Logarithmic loss.
+ confidenceMetrics (List[Dict[str, Any]]):
+ Optional. This is only returned when `only_summary_metrics=False` is
+ passed to evaluate().
+ confusionMatrix (Dict[str, Any]):
+ Optional. This is only returned when `only_summary_metrics=False` is
+ passed to evaluate().
+ """
+
+ label_name: Optional[str] = None
+ auPrc: Optional[float] = None
+ auRoc: Optional[float] = None
+ logLoss: Optional[float] = None
+ confidenceMetrics: Optional[List[Dict[str, Any]]] = None
+ confusionMatrix: Optional[Dict[str, Any]] = None
+
+
+@dataclasses.dataclass
+class EvaluationSlicedClassificationMetric(_EvaluationMetricBase):
+ """The evaluation metric slices returned for classification metrics.
+
+ This is returned when `only_summary_metrics=False` is passed to evaluate().
+
+ Args:
+ overall_metrics (EvaluationClassificationMetric):
+ The evaluation metrics across all slices of data
+ slices (List[EvaluationClassificationMetric]):
+ The evaluation metrics for each label slice.
+ """
+
+ overall_metrics: Optional[EvaluationClassificationMetric] = None
+ slices: Optional[List[EvaluationClassificationMetric]] = None
+
+
+def _populate_eval_template_params(
+ task_spec: _EvaluationTaskSpec,
+ model_name: str,
+ service_account: Optional[str] = None,
+ machine_type: Optional[str] = None,
+ network: Optional[str] = None,
+ encryption_spec_key_name: Optional[str] = None,
+) -> Dict[str, Any]:
+ """Populates a dictionary of template parameters for the evaluation PipelineJob.
+
+ Args:
+ task_spec (EvaluationTaskSpec):
+ The EvaluationTaskSpec passed to evaluate() for this job
+ model_name (str):
+ The resource name of the model being evaluated. Either a PublisherModel or
+ ModelRegistry resource name.
+ service_account (Optional[str]):
+ The default service account for workload run-as account.
+ machine_type (Optional[str]):
+ Optional. The type of the machine to run the evaluation job on.
+ network (Optional[str]):
+ Optional.
+ encryption_spec_key_name (Optional[str]):
+ Optional.
+
+ Returns:
+ Dict[str, Any]:
+ A dictionary of template parameter names and values to be passed to the PipelineJob
+ running the model evaluation.
+ """
+
+ ground_truth_data_gcs_path = task_spec.ground_truth_data
+
+ staging_bucket = aiplatform_initializer.global_config.staging_bucket
+
+ if not staging_bucket:
+ staging_bucket = (
+ gcs_utils.create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist()
+ )
+
+ timestamped_eval_directory = (
+ f"evaluation_data_{aiplatform_utils.timestamped_unique_name()}"
+ )
+
+ if isinstance(task_spec.ground_truth_data, pandas.DataFrame):
+
+ # Convert to jsonl file and upload to gcs
+ dataset_uri = os.path.join(
+ staging_bucket,
+ timestamped_eval_directory,
+ "eval_data.jsonl",
+ )
+
+ gcs_utils._upload_pandas_df_to_gcs(
+ df=task_spec.ground_truth_data, upload_gcs_path=dataset_uri
+ )
+ ground_truth_data_gcs_path = [dataset_uri]
+
+ template_params = {
+ "project": aiplatform_initializer.global_config.project,
+ "location": aiplatform_initializer.global_config.location,
+ "batch_predict_gcs_destination_output_uri": f"{staging_bucket}/{timestamped_eval_directory}",
+ "model_name": model_name,
+ "batch_predict_gcs_source_uris": ground_truth_data_gcs_path,
+ "service_account": service_account,
+ "machine_type": machine_type,
+ "encrytion_spec_key_name": encryption_spec_key_name
+ or aiplatform_initializer.global_config.encryption_spec_key_name,
+ "network": network or aiplatform_initializer.global_config.network,
+ }
+
+ if task_spec.task_name == _TEXT_CLASSIFICATION_TASK_NAME:
+ template_params["evaluation_class_labels"] = task_spec.class_names
+ template_params["target_field_name"] = task_spec.target_column_name
+ else:
+ template_params["evaluation_task"] = task_spec.task_name
+
+ return template_params
+
+
+# TODO (b/285947054): update to use public pipeline contract
+def _get_gcs_uri_from_pipeline_task_details(
+ pipeline_job: aiplatform.PipelineJob,
+) -> Optional[str]:
+ """Gets the GCS URI from the PipelineJob output.
+
+ Args:
+ pipeline_job (aiplatform.PipelineJob)
+ The PipelineJob resource to get the metrics GCS URI from
+
+ Returns:
+ The GCS URI of the evaluation metrics as a string.
+ """
+
+ for task in pipeline_job.task_details:
+ if task.task_name == pipeline_job.name and "evaluation_metrics" in task.outputs:
+ return task.outputs["evaluation_metrics"].artifacts[0].uri
+
+
+def _convert_metrics_dict_to_response_type(
+ metrics_json: Dict[str, Any],
+ metric_type: Type[T],
+ metric_name: Optional[str] = None,
+) -> EvaluationClassificationMetric:
+ metrics_response = metric_type()
+ if metric_name:
+ metrics_response.label_name = metric_name
+
+ for metric, value in metrics_json.items():
+ if hasattr(metrics_response, metric):
+ setattr(metrics_response, metric, value)
+ return metrics_response
+
+
+def _format_classification_metrics(
+ metrics: Dict[str, Any]
+) -> EvaluationSlicedClassificationMetric:
+ """Reformats classification metrics returned by the eval pipeline to make them more readable.
+
+ Returned metrics are of type EvaluationSlicedClassificationMetric, with `overall` representing
+ the metrics for all data, and `slices` representing the metrics for each label in the dataset.
+
+ Example schema of reformatted metrics:
+
+ EvaluationSlicedClassificationMetrics(
+ overall_metrics=EvaluationClassificationMetric(
+ auPrc=...
+ )
+ slices=[
+ EvaluationClassificationMetric(
+ label_name="overall",
+ auPrc=...,
+ ...
+ ),
+ EvaluationClassificationMetric(
+ label_name="label_1",
+ auPrc=...,
+ ...
+ ),
+ EvaluationClassificationMetric(
+ label_name="label_2",
+ auPrc=...,
+ ...
+ )
+ ]
+ )
+ """
+
+ reformatted_metrics = EvaluationSlicedClassificationMetric()
+
+ # TODO: see if we can do this without relying on specific keys, i.e. slicedMetrics
+
+ # First add overall metrics
+ overall_metrics = _convert_metrics_dict_to_response_type(
+ metrics_json=metrics["slicedMetrics"][0]["metrics"]["classification"],
+ metric_type=EvaluationClassificationMetric,
+ )
+ reformatted_metrics.overall_metrics = overall_metrics
+
+ sliced_metrics = []
+
+ # Then add metrics for each slice
+ for idx in range(1, len(metrics["slicedMetrics"])):
+ metric_slice_name = metrics["slicedMetrics"][idx]["singleOutputSlicingSpec"][
+ "value"
+ ]
+
+ sliced_metric = _convert_metrics_dict_to_response_type(
+ metrics_json=metrics["slicedMetrics"][idx]["metrics"]["classification"],
+ metric_type=EvaluationClassificationMetric,
+ metric_name=metric_slice_name,
+ )
+ sliced_metrics.append(sliced_metric)
+
+ reformatted_metrics.sliced_metrics = sliced_metrics
+
+ return reformatted_metrics
+
+
+def _get_metrics_from_gcs_uri(
+ gcs_uri: str,
+) -> Union[
+ EvaluationMetric,
+ EvaluationClassificationMetric,
+ EvaluationSlicedClassificationMetric,
+]:
+ """Downloads evaluation metrics from GCS path."""
+
+ storage_client = storage.Client(
+ credentials=aiplatform_initializer.global_config.credentials
+ )
+
+ metrics_json = storage.Blob.from_string(
+ uri=gcs_uri, client=storage_client
+ ).download_as_text()
+
+ # Sliced classification metrics case, format data
+ if "slicedMetrics" in metrics_json:
+ return _format_classification_metrics(metrics_json)
+ # If classification metrics don't contain slices, use EvaluationClassificationMetric type
+ if "auPrc" in metrics_json:
+ metrics_response = _convert_metrics_dict_to_response_type(
+ metrics_json=metrics_json,
+ metric_type=EvaluationClassificationMetric,
+ )
+ # All other metric types
+ else:
+ metrics_response = _convert_metrics_dict_to_response_type(
+ metrics_json=metrics_json,
+ metric_type=EvaluationMetric,
+ )
+ return metrics_response
+
+
+def _get_metrics_from_pipeline_task_details(
+ pipeline_job: aiplatform.PipelineJob,
+) -> Union[EvaluationMetric, EvaluationClassificationMetric]:
+ """Gets the evaluation metrics from the PipelineJob TaskDetails.
+
+ Args:
+ pipeline_job (aiplatform.PipelineJob)
+ The PipelineJob resource to get the metrics from
+
+ Returns:
+ A dictionary with the evaluation metrics
+ """
+ metrics = {}
+
+ # TODO (b/292076101): this now uses a public pipelines contract, but still relies on task_details
+ for task in pipeline_job.task_details:
+ if task.task_name == pipeline_job.name:
+ for output in task.outputs:
+ for metric_name, metric_value in (
+ task.outputs[output].artifacts[0].metadata.items()
+ ):
+ metrics[metric_name] = metric_value
+
+ if "auPrc" in metrics:
+ metrics_response = EvaluationClassificationMetric()
+ else:
+ metrics_response = EvaluationMetric()
+
+ for metric, value in metrics.items():
+ if hasattr(metrics_response, metric):
+ setattr(metrics_response, metric, value)
+ return metrics_response
+
+
+class _LanguageModelEvaluationJob:
+ """Represents a model evaluation job for LLM models.
+
+ These evaluation jobs are run as a Vertex Pipeline.
+ """
+
+ def __init__(
+ self,
+ pipeline_job: aiplatform.PipelineJob,
+ ):
+ self._pipeline_job = pipeline_job
+
+ def result(
+ self, *, only_summary_metrics: bool
+ ) -> Union[EvaluationMetric, EvaluationClassificationMetric]:
+ """Blocks on completion of the model evaluation PipelineJob and returns metrics."""
+
+ self._pipeline_job.wait()
+
+ if only_summary_metrics:
+ return _get_metrics_from_pipeline_task_details(self._pipeline_job)
+ else:
+ gcs_uri = _get_gcs_uri_from_pipeline_task_details(self._pipeline_job)
+ if gcs_uri:
+ return _get_metrics_from_gcs_uri(gcs_uri)
+
+
+class _EvaluatableLanguageModel:
+ """Mixin class for LLMs that support model evaluation."""
+
+ # TODO (b/282975912): convert training job specific args to a TrainingConfig
+ def evaluate(
+ self,
+ *,
+ task_spec: _EvaluationTaskSpec,
+ only_summary_metrics: Optional[bool] = True,
+ machine_type: Optional[str] = None,
+ ) -> Union[
+ EvaluationMetric,
+ EvaluationClassificationMetric,
+ EvaluationSlicedClassificationMetric,
+ ]:
+ """Runs model evaluation using the provided input and ground truth data.
+
+ This creates an evaluation job and blocks until the job completes, about
+ 10 - 20 minutes.
+
+ Example:
+ ```
+ model = TextGenerationModel.from_pretrained("text-bison@001")
+ eval_metrics = model.evaluate(
+ task_spec=EvaluationTextGenerationSpec(
+ ground_truth_data="gs://my-bucket/ground-truth.jsonl",
+ )
+ )
+ ```
+
+ Args:
+ task_spec (_EvaluationTaskSpec):
+ Required. The configuration spec for your model evaluation job. Choose the spec corresponding
+ with the evaluation task you are performing, one of: EvaluationClassificationSpec, EvaluationTextGenerationSpec,
+ EvaluationTextSummarizationSpec, EvaluationQuestionAnsweringSpec.
+
+ For example, a valid classification `task_spec` is:
+ EvaluationTextClassificationSpec(
+ ground_truth_data=["gs://bucket/path/to/your/data.jsonl"],
+ class_names=["cheddar", "gouda", "camembert"],
+ target_column_name="cheese_type",
+ )
+ only_summary_metrics (bool):
+ Optional. Setting this field to False only affects the metrics returned for text classification tasks.
+ When False, text classification metrics will include additional sliced metrics fields, with metrics for
+ each label slice in the data.
+ machine_type (str):
+ Optional. The type of the machine to run the evaluation job on. The default value is "e2-highmem-16". For
+ tasks with a large evaluation dataset, a bigger machine type may be required.
+ For more details about this input config, see
+ https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types.
+
+ Returns:
+ Union[EvaluationMetric, EvaluationClassificationMetric, List[EvaluationClassificationMetric]]
+ The evaluation metrics from this evaluation job. When `only_summary_metrics=False` is passed
+ and the evaluation task type is 'text-classification', the return type will be List[EvaluationClassificationMetric],
+ where each value in the list is the metrics associated with a particular classification label.
+ """
+
+ model_info = _model_garden_models._get_model_info(
+ self._model_id,
+ schema_to_class_map={self._INSTANCE_SCHEMA_URI: type(self)},
+ )
+ model_name = _get_model_resource_name_and_validate(
+ model_name=self._model_resource_name, model_info=model_info
+ )
+
+ # TODO(b/296402511): get service_account from aiplatform_initializer and pass it to the template here and to PipelineJob after cl/539823838 is submitted
+ template_params = _populate_eval_template_params(
+ task_spec=task_spec,
+ model_name=model_name,
+ machine_type=machine_type,
+ network=aiplatform_initializer.global_config.network,
+ encryption_spec_key_name=aiplatform_initializer.global_config.encryption_spec_key_name,
+ )
+
+ template_path = _get_template_url(task_spec.task_name)
+
+ pipeline_job = aiplatform.PipelineJob(
+ template_path=template_path,
+ parameter_values=template_params,
+ display_name=f"llm-eval-sdk-{aiplatform_utils.timestamped_unique_name()}",
+ )
+ pipeline_job.submit()
+
+ eval_job = _LanguageModelEvaluationJob(pipeline_job=pipeline_job)
+
+ _LOGGER.info(
+ "Your evaluation job is running and will take 15-20 minutes to complete. Click on the PipelineJob link to view progress."
+ )
+
+ # NOTE: only_summary_metrics is passed because getting metrics from the artifact is faster than downloading from GCS
+ # GCS is only needed for additional metrics for text-classification tasks
+ return eval_job.result(only_summary_metrics=only_summary_metrics)
+
+ def list_evaluation_metrics(
+ self,
+ *,
+ task_name: Optional[str] = None,
+ only_summary_metrics: Optional[bool] = True,
+ ) -> List[Union[EvaluationMetric, EvaluationClassificationMetric]]:
+ """Lists the evaluation metrics from all evaluation jobs run on this model.
+
+ Args:
+ task_name (str):
+ Optional. The task name to return evaluation metrics for. If provided, this will only return evaluation
+ metrics for tasks of the provided type. This matches the possible values passed to EvaluationTaskType.task_name,
+ and must be one of 'text-generation', 'text-classification', 'summarization', or 'question-answering'.
+
+ Returns:
+ Dict[str, Any]
+ The evaluation metrics from all evaluation jobs run on this model.
+
+ """
+
+ model_name = self._model_resource_name
+
+ publisher_model_parts = model_garden_service_client.ModelGardenServiceClient.parse_publisher_model_path(
+ "".join(model_name.rpartition("publishers")[1:])
+ )
+
+ if publisher_model_parts:
+ model_id = publisher_model_parts["model"]
+ model_name = f"publishers/google/models/{model_id}"
+
+ filters = f'metadata.component_type.string_value={_EVALUATION_PIPELINE_COMPONENT_IDENTIFIER} AND metadata."input:model_name".string_value={model_name} AND (metadata."input:evaluation_task".string_value={_TEXT_GENERATION_TASK_NAME} OR metadata."input:evaluation_task".string_value={_SUMMARIZATION_TASK_NAME} OR metadata."input:evaluation_task".string_value={_QA_TASK_NAME} OR metadata."input:evaluation_task".string_value={_TEXT_CLASSIFICATION_TASK_NAME})'
+
+ # NOTE: when task_name is appended to the filter the block of OR filters in `filters` above becomes a no-op
+ if task_name:
+ filters += f' AND metadata."input:evaluation_task".string_value={task_name}'
+
+ filtered_pipeline_executions = aiplatform.Execution.list(
+ filter=filters,
+ project=aiplatform_initializer.global_config.project,
+ location=aiplatform_initializer.global_config.location,
+ credentials=aiplatform_initializer.global_config.credentials,
+ )
+
+ model_eval_metrics = []
+
+ # TODO (b/285950380): improve performance of this method
+ for pipeline_execution in filtered_pipeline_executions:
+ if "pipeline_job_resource_name" not in pipeline_execution.metadata:
+ continue
+
+ pipeline_job_resource = aiplatform.PipelineJob.get(
+ resource_name=pipeline_execution.metadata["pipeline_job_resource_name"]
+ )
+ eval_job_state = pipeline_job_resource._gca_resource.state
+
+ if (
+ eval_job_state
+ != gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+ ):
+ continue
+
+ metrics = None
+
+ if only_summary_metrics:
+ metrics = _get_metrics_from_pipeline_task_details(pipeline_job_resource)
+ else:
+ gcs_uri = _get_gcs_uri_from_pipeline_task_details(pipeline_job_resource)
+ if gcs_uri:
+ metrics = _get_metrics_from_gcs_uri(gcs_uri)
+
+ metrics.input_dataset_paths = pipeline_execution.metadata[
+ "input:batch_predict_gcs_source_uris"
+ ]
+ metrics.task_name = pipeline_execution.metadata["input:evaluation_task"]
+
+ model_eval_metrics.append(metrics)
+
+ return model_eval_metrics
diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py
index 620bc4e708..cfa484d9dc 100644
--- a/vertexai/language_models/_language_models.py
+++ b/vertexai/language_models/_language_models.py
@@ -15,15 +15,20 @@
"""Classes for working with language models."""
import dataclasses
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
import warnings
from google.cloud import aiplatform
+from google.cloud.aiplatform import _streaming_prediction
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer as aiplatform_initializer
from google.cloud.aiplatform import utils as aiplatform_utils
+from google.cloud.aiplatform.compat import types as aiplatform_types
from google.cloud.aiplatform.utils import gcs_utils
from vertexai._model_garden import _model_garden_models
+from vertexai.language_models import (
+ _evaluatable_language_models,
+)
try:
import pandas
@@ -84,6 +89,13 @@ def _model_resource_name(self) -> str:
return self._endpoint.list_models()[0].model
+@dataclasses.dataclass
+class _PredictionRequest:
+ """A single-instance prediction request."""
+ instance: Dict[str, Any]
+ parameters: Optional[Dict[str, Any]] = None
+
+
class _TunableModelMixin(_LanguageModel):
"""Model that can be tuned."""
@@ -137,8 +149,112 @@ def tune_model(
self,
training_data: Union[str, "pandas.core.frame.DataFrame"],
*,
- train_steps: int = 1000,
+ train_steps: Optional[int] = None,
learning_rate: Optional[float] = None,
+ learning_rate_multiplier: Optional[float] = None,
+ tuning_job_location: Optional[str] = None,
+ tuned_model_location: Optional[str] = None,
+ model_display_name: Optional[str] = None,
+ tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
+ default_context: Optional[str] = None,
+ ) -> "_LanguageModelTuningJob":
+ """Tunes a model based on training data.
+
+ This method launches and returns an asynchronous model tuning job.
+ Usage:
+ ```
+ tuning_job = model.tune_model(...)
+ ... do some other work
+ tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
+ ```
+
+ Args:
+ training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
+ The dataset schema is model-specific.
+ See https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-models#dataset_format
+ train_steps: Number of training batches to tune on (batch size is 8 samples).
+ learning_rate: Deprecated. Use learning_rate_multiplier instead.
+ Learning rate to use in tuning.
+ learning_rate_multiplier: Learning rate multiplier to use in tuning.
+ tuning_job_location: GCP location where the tuning job should be run.
+ Only "europe-west4" and "us-central1" locations are supported for now.
+ tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
+ model_display_name: Custom display name for the tuned model.
+ tuning_evaluation_spec: Specification for the model evaluation during tuning.
+ default_context: The context to use for all training samples by default.
+
+ Returns:
+ A `LanguageModelTuningJob` object that represents the tuning job.
+ Calling `job.result()` blocks until the tuning is complete and returns a `LanguageModel` object.
+
+ Raises:
+ ValueError: If the "tuning_job_location" value is not supported
+ ValueError: If the "tuned_model_location" value is not supported
+ RuntimeError: If the model does not support tuning
+ """
+ tuning_parameters = {}
+ if train_steps is not None:
+ tuning_parameters["train_steps"] = train_steps
+ if learning_rate is not None:
+ _LOGGER.warning(
+ "The learning_rate parameter is deprecated."
+ "Use the learning_rate_multiplier parameter instead."
+ )
+ tuning_parameters["learning_rate"] = learning_rate
+ if learning_rate_multiplier is not None:
+ tuning_parameters["learning_rate_multiplier"] = learning_rate_multiplier
+ eval_spec = tuning_evaluation_spec
+ if eval_spec is not None:
+ if isinstance(eval_spec.evaluation_data, str):
+ if eval_spec.evaluation_data.startswith("gs://"):
+ tuning_parameters["evaluation_data_uri"] = eval_spec.evaluation_data
+ else:
+ raise ValueError("evaluation_data should be a GCS URI")
+ else:
+ raise TypeError("evaluation_data should be a URI string")
+ if eval_spec.evaluation_interval is not None:
+ tuning_parameters["evaluation_interval"] = eval_spec.evaluation_interval
+ if eval_spec.enable_early_stopping is not None:
+ tuning_parameters[
+ "enable_early_stopping"
+ ] = eval_spec.enable_early_stopping
+ if eval_spec.tensorboard is not None:
+ if isinstance(eval_spec.tensorboard, aiplatform.Tensorboard):
+ if eval_spec.tensorboard.location != tuning_job_location:
+ raise ValueError(
+ "The Tensorboard must be in the same location as the tuning job."
+ )
+ tuning_parameters[
+ "tensorboard_resource_id"
+ ] = eval_spec.tensorboard.resource_name
+ elif isinstance(eval_spec.tensorboard, str):
+ resource_name_parts = aiplatform.Tensorboard._parse_resource_name(
+ eval_spec.tensorboard
+ )
+ if resource_name_parts["location"] != tuning_job_location:
+ raise ValueError(
+ "The Tensorboard must be in the same location as the tuning job."
+ )
+ tuning_parameters["tensorboard_resource_id"] = eval_spec.tensorboard
+ else:
+ raise TypeError("tensorboard should be a URI string")
+
+ if default_context:
+ tuning_parameters["default_context"] = default_context
+
+ return self._tune_model(
+ training_data=training_data,
+ tuning_parameters=tuning_parameters,
+ tuning_job_location=tuning_job_location,
+ tuned_model_location=tuned_model_location,
+ model_display_name=model_display_name,
+ )
+
+ def _tune_model(
+ self,
+ training_data: Union[str, "pandas.core.frame.DataFrame"],
+ *,
+ tuning_parameters: Dict[str, Any],
tuning_job_location: Optional[str] = None,
tuned_model_location: Optional[str] = None,
model_display_name: Optional[str] = None,
@@ -151,8 +267,7 @@ def tune_model(
training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
The dataset schema is model-specific.
See https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-models#dataset_format
- train_steps: Number of training batches to tune on (batch size is 8 samples).
- learning_rate: Learning rate for the tuning
+ tuning_parameters: Tuning pipeline parameter values.
tuning_job_location: GCP location where the tuning job should be run.
Only "europe-west4" and "us-central1" locations are supported for now.
tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
@@ -184,11 +299,10 @@ def tune_model(
raise RuntimeError(f"The {self._model_id} model does not support tuning")
pipeline_job = _launch_tuning_job(
training_data=training_data,
- train_steps=train_steps,
model_id=model_info.tuning_model_id,
tuning_pipeline_uri=model_info.tuning_pipeline_uri,
+ tuning_parameters=tuning_parameters,
model_display_name=model_display_name,
- learning_rate=learning_rate,
tuning_job_location=tuning_job_location,
)
@@ -196,11 +310,281 @@ def tune_model(
base_model=self,
job=pipeline_job,
)
- self._job = job
- tuned_model = job.result()
- # The UXR study attendees preferred to tune model in place
+ return job
+
+
+class _TunableTextModelMixin(_TunableModelMixin):
+ """Text model that can be tuned."""
+
+ def tune_model(
+ self,
+ training_data: Union[str, "pandas.core.frame.DataFrame"],
+ *,
+ train_steps: Optional[int] = None,
+ learning_rate_multiplier: Optional[float] = None,
+ tuning_job_location: Optional[str] = None,
+ tuned_model_location: Optional[str] = None,
+ model_display_name: Optional[str] = None,
+ tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
+ ) -> "_LanguageModelTuningJob":
+ """Tunes a model based on training data.
+
+ This method launches and returns an asynchronous model tuning job.
+ Usage:
+ ```
+ tuning_job = model.tune_model(...)
+ ... do some other work
+ tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
+
+ Args:
+ training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
+ The dataset schema is model-specific.
+ See https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-models#dataset_format
+ train_steps: Number of training batches to tune on (batch size is 8 samples).
+ learning_rate_multiplier: Learning rate multiplier to use in tuning.
+ tuning_job_location: GCP location where the tuning job should be run.
+ Only "europe-west4" and "us-central1" locations are supported for now.
+ tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
+ model_display_name: Custom display name for the tuned model.
+ tuning_evaluation_spec: Specification for the model evaluation during tuning.
+
+ Returns:
+ A `LanguageModelTuningJob` object that represents the tuning job.
+ Calling `job.result()` blocks until the tuning is complete and returns a `LanguageModel` object.
+
+ Raises:
+ ValueError: If the "tuning_job_location" value is not supported
+ ValueError: If the "tuned_model_location" value is not supported
+ RuntimeError: If the model does not support tuning
+ """
+ # Note: Chat models do not support default_context
+ return super().tune_model(
+ training_data=training_data,
+ train_steps=train_steps,
+ learning_rate_multiplier=learning_rate_multiplier,
+ tuning_job_location=tuning_job_location,
+ tuned_model_location=tuned_model_location,
+ model_display_name=model_display_name,
+ tuning_evaluation_spec=tuning_evaluation_spec,
+ )
+
+
+class _PreviewTunableTextModelMixin(_TunableModelMixin):
+ """Text model that can be tuned."""
+
+ def tune_model(
+ self,
+ training_data: Union[str, "pandas.core.frame.DataFrame"],
+ *,
+ train_steps: int = 1000,
+ learning_rate: Optional[float] = None,
+ learning_rate_multiplier: Optional[float] = None,
+ tuning_job_location: Optional[str] = None,
+ tuned_model_location: Optional[str] = None,
+ model_display_name: Optional[str] = None,
+ tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
+ ) -> "_LanguageModelTuningJob":
+ """Tunes a model based on training data.
+
+ This method launches a model tuning job, waits for completion,
+ updates the model in-place. This method returns job object for forward
+ compatibility.
+ In the future (GA), this method will become asynchronous and will stop
+ updating the model in-place.
+
+ Usage:
+ ```
+ tuning_job = model.tune_model(...) # Blocks until tuning is complete
+ tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
+ ```
+
+ Args:
+ training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
+ The dataset schema is model-specific.
+ See https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-models#dataset_format
+ train_steps: Number of training batches to tune on (batch size is 8 samples).
+ learning_rate: Deprecated. Use learning_rate_multiplier instead.
+ Learning rate to use in tuning.
+ learning_rate_multiplier: Learning rate multiplier to use in tuning.
+ tuning_job_location: GCP location where the tuning job should be run.
+ Only "europe-west4" and "us-central1" locations are supported for now.
+ tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
+ model_display_name: Custom display name for the tuned model.
+ tuning_evaluation_spec: Specification for the model evaluation during tuning.
+
+ Returns:
+ A `LanguageModelTuningJob` object that represents the tuning job.
+ Calling `job.result()` blocks until the tuning is complete and returns a `LanguageModel` object.
+
+ Raises:
+ ValueError: If the "tuning_job_location" value is not supported
+ ValueError: If the "tuned_model_location" value is not supported
+ RuntimeError: If the model does not support tuning
+ """
+ # Note: Chat models do not support default_context
+ job = super().tune_model(
+ training_data=training_data,
+ train_steps=train_steps,
+ learning_rate=learning_rate,
+ learning_rate_multiplier=learning_rate_multiplier,
+ tuning_job_location=tuning_job_location,
+ tuned_model_location=tuned_model_location,
+ model_display_name=model_display_name,
+ tuning_evaluation_spec=tuning_evaluation_spec,
+ )
+ tuned_model = job.get_tuned_model()
+ self._endpoint = tuned_model._endpoint
+ self._endpoint_name = tuned_model._endpoint_name
+ return job
+
+
+class _TunableChatModelMixin(_TunableModelMixin):
+ """Chat model that can be tuned."""
+
+ def tune_model(
+ self,
+ training_data: Union[str, "pandas.core.frame.DataFrame"],
+ *,
+ train_steps: Optional[int] = None,
+ learning_rate_multiplier: Optional[float] = None,
+ tuning_job_location: Optional[str] = None,
+ tuned_model_location: Optional[str] = None,
+ model_display_name: Optional[str] = None,
+ default_context: Optional[str] = None,
+ ) -> "_LanguageModelTuningJob":
+ """Tunes a model based on training data.
+
+ This method launches and returns an asynchronous model tuning job.
+ Usage:
+ ```
+ tuning_job = model.tune_model(...)
+ ... do some other work
+ tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
+ ```
+
+ Args:
+ training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
+ The dataset schema is model-specific.
+ See https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-models#dataset_format
+ train_steps: Number of training batches to tune on (batch size is 8 samples).
+ learning_rate: Deprecated. Use learning_rate_multiplier instead.
+ Learning rate to use in tuning.
+ learning_rate_multiplier: Learning rate multiplier to use in tuning.
+ tuning_job_location: GCP location where the tuning job should be run.
+ Only "europe-west4" and "us-central1" locations are supported for now.
+ tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
+ model_display_name: Custom display name for the tuned model.
+ default_context: The context to use for all training samples by default.
+
+ Returns:
+ A `LanguageModelTuningJob` object that represents the tuning job.
+ Calling `job.result()` blocks until the tuning is complete and returns a `LanguageModel` object.
+
+ Raises:
+ ValueError: If the "tuning_job_location" value is not supported
+ ValueError: If the "tuned_model_location" value is not supported
+ RuntimeError: If the model does not support tuning
+ """
+ # Note: Chat models do not support tuning_evaluation_spec
+ return super().tune_model(
+ training_data=training_data,
+ train_steps=train_steps,
+ learning_rate_multiplier=learning_rate_multiplier,
+ tuning_job_location=tuning_job_location,
+ tuned_model_location=tuned_model_location,
+ model_display_name=model_display_name,
+ default_context=default_context,
+ )
+
+
+class _PreviewTunableChatModelMixin(_TunableModelMixin):
+ """Chat model that can be tuned."""
+
+ def tune_model(
+ self,
+ training_data: Union[str, "pandas.core.frame.DataFrame"],
+ *,
+ train_steps: int = 1000,
+ learning_rate: Optional[float] = None,
+ learning_rate_multiplier: Optional[float] = None,
+ tuning_job_location: Optional[str] = None,
+ tuned_model_location: Optional[str] = None,
+ model_display_name: Optional[str] = None,
+ default_context: Optional[str] = None,
+ ) -> "_LanguageModelTuningJob":
+ """Tunes a model based on training data.
+
+ This method launches a model tuning job, waits for completion,
+ updates the model in-place. This method returns job object for forward
+ compatibility.
+ In the future (GA), this method will become asynchronous and will stop
+ updating the model in-place.
+
+ Usage:
+ ```
+ tuning_job = model.tune_model(...) # Blocks until tuning is complete
+ tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
+ ```
+
+ Args:
+ training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
+ The dataset schema is model-specific.
+ See https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-models#dataset_format
+ train_steps: Number of training batches to tune on (batch size is 8 samples).
+ learning_rate: Deprecated. Use learning_rate_multiplier instead.
+ Learning rate to use in tuning.
+ learning_rate_multiplier: Learning rate multiplier to use in tuning.
+ tuning_job_location: GCP location where the tuning job should be run.
+ Only "europe-west4" and "us-central1" locations are supported for now.
+ tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
+ model_display_name: Custom display name for the tuned model.
+ default_context: The context to use for all training samples by default.
+
+ Returns:
+ A `LanguageModelTuningJob` object that represents the tuning job.
+ Calling `job.result()` blocks until the tuning is complete and returns a `LanguageModel` object.
+
+ Raises:
+ ValueError: If the "tuning_job_location" value is not supported
+ ValueError: If the "tuned_model_location" value is not supported
+ RuntimeError: If the model does not support tuning
+ """
+ # Note: Chat models do not support tuning_evaluation_spec
+ job = super().tune_model(
+ training_data=training_data,
+ train_steps=train_steps,
+ learning_rate=learning_rate,
+ learning_rate_multiplier=learning_rate_multiplier,
+ tuning_job_location=tuning_job_location,
+ tuned_model_location=tuned_model_location,
+ model_display_name=model_display_name,
+ default_context=default_context,
+ )
+ tuned_model = job.get_tuned_model()
self._endpoint = tuned_model._endpoint
self._endpoint_name = tuned_model._endpoint_name
+ return job
+
+
+@dataclasses.dataclass
+class TuningEvaluationSpec:
+ """Specification for model evaluation to perform during tuning.
+
+ Attributes:
+ evaluation_data: GCS URI of the evaluation dataset. This will run
+ model evaluation as part of the tuning job.
+ evaluation_interval: The evaluation will run at every
+ evaluation_interval tuning steps. Default: 20.
+ enable_early_stopping: If True, the tuning may stop early before
+ completing all the tuning steps. Requires evaluation_data.
+ tensorboard: Vertex Tensorboard where to write the evaluation metrics.
+ The Tensorboard must be in the same location as the tuning job.
+ """
+
+ evaluation_data: str
+ evaluation_interval: Optional[int] = None
+ enable_early_stopping: Optional[bool] = None
+ tensorboard: Optional[Union[aiplatform.Tensorboard, str]] = None
@dataclasses.dataclass
@@ -222,6 +606,11 @@ class TextGenerationResponse:
def __repr__(self):
return self.text
+ @property
+ def raw_prediction_response(self) -> aiplatform.models.Prediction:
+ """Raw prediction response."""
+ return self._prediction_response
+
class _TextGenerationModel(_LanguageModel):
"""TextGenerationModel represents a general language model.
@@ -247,6 +636,7 @@ def predict(
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
+ stop_sequences: Optional[List[str]] = None,
) -> "TextGenerationResponse":
"""Gets model response for a single prompt.
@@ -256,6 +646,7 @@ def predict(
temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+ stop_sequences: Customized stop sequences to stop the decoding process.
Returns:
A `TextGenerationResponse` object that contains the text produced by the model.
@@ -267,6 +658,7 @@ def predict(
temperature=temperature,
top_k=top_k,
top_p=top_p,
+ stop_sequences=stop_sequences,
)[0]
def _batch_predict(
@@ -276,6 +668,7 @@ def _batch_predict(
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
+ stop_sequences: Optional[List[str]] = None,
) -> List["TextGenerationResponse"]:
"""Gets model response for a single prompt.
@@ -285,6 +678,7 @@ def _batch_predict(
temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+ stop_sequences: Customized stop sequences to stop the decoding process.
Returns:
A list of `TextGenerationResponse` objects that contain the texts produced by the model.
@@ -304,6 +698,9 @@ def _batch_predict(
if top_k:
prediction_parameters["topK"] = top_k
+ if stop_sequences:
+ prediction_parameters["stopSequences"] = stop_sequences
+
prediction_response = self._endpoint.predict(
instances=instances,
parameters=prediction_parameters,
@@ -327,6 +724,70 @@ def _batch_predict(
)
return results
+ def predict_streaming(
+ self,
+ prompt: str,
+ *,
+ max_output_tokens: int = _DEFAULT_MAX_OUTPUT_TOKENS,
+ temperature: Optional[float] = None,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ ) -> Iterator[TextGenerationResponse]:
+ """Gets a streaming model response for a single prompt.
+
+ The result is a stream (generator) of partial responses.
+
+ Args:
+ prompt: Question to ask the model.
+ max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+ temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+ top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
+ top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+
+ Yields:
+ A stream of `TextGenerationResponse` objects that contain partial
+ responses produced by the model.
+ """
+ prediction_service_client = self._endpoint._prediction_client
+ # Note: "prompt", not "content" like in the non-streaming case. b/294462691
+ instance = {"prompt": prompt}
+ prediction_parameters = {}
+
+ if max_output_tokens:
+ prediction_parameters["maxDecodeSteps"] = max_output_tokens
+
+ if temperature is not None:
+ prediction_parameters["temperature"] = temperature
+
+ if top_p:
+ prediction_parameters["topP"] = top_p
+
+ if top_k:
+ prediction_parameters["topK"] = top_k
+
+ for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
+ prediction_service_client=prediction_service_client,
+ endpoint_name=self._endpoint_name,
+ instance=instance,
+ parameters=prediction_parameters,
+ ):
+ safety_attributes_dict = prediction_dict.get("safetyAttributes", {})
+ prediction_obj = aiplatform.models.Prediction(
+ predictions=[prediction_dict],
+ deployed_model_id="",
+ )
+ yield TextGenerationResponse(
+ text=prediction_dict["content"],
+ _prediction_response=prediction_obj,
+ is_blocked=safety_attributes_dict.get("blocked", False),
+ safety_attributes=dict(
+ zip(
+ safety_attributes_dict.get("categories") or [],
+ safety_attributes_dict.get("scores") or [],
+ )
+ ),
+ )
+
class _ModelWithBatchPredict(_LanguageModel):
"""Model that supports batch prediction."""
@@ -433,12 +894,17 @@ def batch_predict(
)
-class TextGenerationModel(_TextGenerationModel, _ModelWithBatchPredict):
+class TextGenerationModel(
+ _TextGenerationModel, _TunableTextModelMixin, _ModelWithBatchPredict
+):
pass
class _PreviewTextGenerationModel(
- _TextGenerationModel, _TunableModelMixin, _PreviewModelWithBatchPredict
+ _TextGenerationModel,
+ _PreviewTunableTextModelMixin,
+ _PreviewModelWithBatchPredict,
+ _evaluatable_language_models._EvaluatableLanguageModel,
):
# Do not add docstring so that it's inherited from the base class.
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
@@ -555,8 +1021,33 @@ def send_message(
return response_obj
+@dataclasses.dataclass
+class TextEmbeddingInput:
+ """Structural text embedding input.
+
+ Attributes:
+ text: The main text content to embed.
+ task_type: The name of the downstream task the embeddings will be used for.
+ Valid values:
+ RETRIEVAL_QUERY
+ Specifies the given text is a query in a search/retrieval setting.
+ RETRIEVAL_DOCUMENT
+ Specifies the given text is a document from the corpus being searched.
+ SEMANTIC_SIMILARITY
+ Specifies the given text will be used for STS.
+ CLASSIFICATION
+ Specifies that the given text will be classified.
+ CLUSTERING
+ Specifies that the embeddings will be used for clustering.
+ title: Optional identifier of the text content.
+ """
+ text: str
+ task_type: Optional[str] = None
+ title: Optional[str] = None
+
+
class TextEmbeddingModel(_LanguageModel):
- """TextEmbeddingModel converts text into a vector of floating-point numbers.
+ """TextEmbeddingModel class calculates embeddings for the given texts.
Examples::
@@ -574,36 +1065,76 @@ class TextEmbeddingModel(_LanguageModel):
"gs://google-cloud-aiplatform/schema/predict/instance/text_embedding_1.0.0.yaml"
)
- def get_embeddings(self, texts: List[str]) -> List["TextEmbedding"]:
- instances = [{"content": str(text)} for text in texts]
+ def get_embeddings(self,
+ texts: List[Union[str, TextEmbeddingInput]],
+ *,
+ auto_truncate: bool = True,
+ ) -> List["TextEmbedding"]:
+ """Calculates embeddings for the given texts.
+
+ Args:
+ texts(str): A list of texts or `TextEmbeddingInput` objects to embed.
+ auto_truncate(bool): Whether to automatically truncate long texts. Default: True.
+
+ Returns:
+ A list of `TextEmbedding` objects.
+ """
+ instances = []
+ for text in texts:
+ if isinstance(text, TextEmbeddingInput):
+ instance = {"content": text.text}
+ if text.task_type:
+ instance["taskType"] = text.task_type
+ if text.title:
+ instance["title"] = text.title
+ elif isinstance(text, str):
+ instance = {"content": text}
+ else:
+ raise TypeError(f"Unsupported text embedding input type: {text}.")
+ instances.append(instance)
+ parameters = {"autoTruncate": auto_truncate}
prediction_response = self._endpoint.predict(
instances=instances,
+ parameters=parameters,
)
- return [
- TextEmbedding(
- values=prediction["embeddings"]["values"],
+ results = []
+ for prediction in prediction_response.predictions:
+ embeddings = prediction["embeddings"]
+ statistics = embeddings["statistics"]
+ result = TextEmbedding(
+ values=embeddings["values"],
+ statistics=TextEmbeddingStatistics(
+ token_count=statistics["token_count"],
+ truncated=statistics["truncated"],
+ ),
_prediction_response=prediction_response,
)
- for prediction in prediction_response.predictions
- ]
+ results.append(result)
+
+ return results
class _PreviewTextEmbeddingModel(TextEmbeddingModel, _ModelWithBatchPredict):
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
+@dataclasses.dataclass
+class TextEmbeddingStatistics:
+ """Text embedding statistics."""
+
+ token_count: int
+ truncated: bool
+
+
+@dataclasses.dataclass
class TextEmbedding:
- """Contains text embedding vector."""
+ """Text embedding vector and statistics."""
- def __init__(
- self,
- values: List[float],
- _prediction_response: Any = None,
- ):
- self.values = values
- self._prediction_response = _prediction_response
+ values: List[float]
+ statistics: TextEmbeddingStatistics
+ _prediction_response: aiplatform.models.Prediction = None
@dataclasses.dataclass
@@ -637,11 +1168,14 @@ def start_chat(
*,
context: Optional[str] = None,
examples: Optional[List[InputOutputTextPair]] = None,
- max_output_tokens: Optional[int] = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
+ max_output_tokens: Optional[
+ int
+ ] = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
message_history: Optional[List[ChatMessage]] = None,
+ stop_sequences: Optional[List[str]] = None,
) -> "ChatSession":
"""Starts a chat session with the model.
@@ -655,6 +1189,7 @@ def start_chat(
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
message_history: A list of previously sent and received messages.
+ stop_sequences: Customized stop sequences to stop the decoding process.
Returns:
A `ChatSession` object.
@@ -668,6 +1203,7 @@ def start_chat(
top_k=top_k,
top_p=top_p,
message_history=message_history,
+ stop_sequences=stop_sequences,
)
@@ -699,7 +1235,7 @@ class ChatModel(_ChatModelBase):
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml"
-class _PreviewChatModel(ChatModel, _TunableModelMixin):
+class _PreviewChatModel(ChatModel, _PreviewTunableChatModelMixin):
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
@@ -742,11 +1278,11 @@ def start_chat(
model=self,
max_output_tokens=max_output_tokens,
temperature=temperature,
- message_history=message_history
+ message_history=message_history,
)
-class _PreviewCodeChatModel(CodeChatModel, _TunableModelMixin):
+class _PreviewCodeChatModel(CodeChatModel, _TunableChatModelMixin):
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
@@ -761,11 +1297,14 @@ def __init__(
model: _ChatModelBase,
context: Optional[str] = None,
examples: Optional[List[InputOutputTextPair]] = None,
- max_output_tokens: Optional[int] = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
+ max_output_tokens: Optional[
+ int
+ ] = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
message_history: Optional[List[ChatMessage]] = None,
+ stop_sequences: Optional[List[str]] = None,
):
self._model = model
self._context = context
@@ -775,13 +1314,14 @@ def __init__(
self._top_k = top_k
self._top_p = top_p
self._message_history: List[ChatMessage] = message_history or []
+ self._stop_sequences = stop_sequences
@property
def message_history(self) -> List[ChatMessage]:
"""List of previous messages."""
return self._message_history
- def send_message(
+ def _prepare_request(
self,
message: str,
*,
@@ -789,8 +1329,9 @@ def send_message(
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
- ) -> "TextGenerationResponse":
- """Sends message to the language model and gets a response.
+ stop_sequences: Optional[List[str]] = None,
+ ) -> _PredictionRequest:
+ """Prepares a request for the language model.
Args:
message: Message to send to the model
@@ -802,9 +1343,10 @@ def send_message(
Uses the value specified when calling `ChatModel.start_chat` by default.
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
Uses the value specified when calling `ChatModel.start_chat` by default.
+ stop_sequences: Customized stop sequences to stop the decoding process.
Returns:
- A `TextGenerationResponse` object that contains the text produced by the model.
+ A `_PredictionRequest` object.
"""
prediction_parameters = {}
@@ -825,6 +1367,10 @@ def send_message(
if top_k:
prediction_parameters["topK"] = top_k
+ stop_sequences = stop_sequences or self._stop_sequences
+ if stop_sequences:
+ prediction_parameters["stopSequences"] = stop_sequences
+
message_structs = []
for past_message in self._message_history:
message_structs.append(
@@ -852,27 +1398,90 @@ def send_message(
for example in self._examples
]
- prediction_response = self._model._endpoint.predict(
- instances=[prediction_instance],
+ return _PredictionRequest(
+ instance=prediction_instance,
parameters=prediction_parameters,
)
- prediction = prediction_response.predictions[0]
+ @classmethod
+ def _parse_chat_prediction_response(
+ cls,
+ prediction_response: aiplatform.models.Prediction,
+ prediction_idx: int = 0,
+ candidate_idx: int = 0,
+ ) -> TextGenerationResponse:
+ """Parses prediction response for chat models.
+
+ Args:
+ prediction_response: Prediction response received from the model
+ prediction_idx: Index of the prediction to parse.
+ candidate_idx: Index of the candidate to parse.
+
+ Returns:
+ A `TextGenerationResponse` object.
+ """
+ prediction = prediction_response.predictions[prediction_idx]
# ! Note: For chat models, the safetyAttributes is a list.
- safety_attributes = prediction["safetyAttributes"][0]
- response_obj = TextGenerationResponse(
- text=prediction["candidates"][0]["content"]
+ safety_attributes = prediction["safetyAttributes"][candidate_idx]
+ return TextGenerationResponse(
+ text=prediction["candidates"][candidate_idx]["content"]
if prediction.get("candidates")
else None,
_prediction_response=prediction_response,
is_blocked=safety_attributes.get("blocked", False),
safety_attributes=dict(
zip(
- safety_attributes.get("categories", []),
- safety_attributes.get("scores", []),
+ # Unlike with normal prediction, in streaming prediction
+ # categories and scores can be None
+ safety_attributes.get("categories") or [],
+ safety_attributes.get("scores") or [],
)
),
)
+
+ def send_message(
+ self,
+ message: str,
+ *,
+ max_output_tokens: Optional[int] = None,
+ temperature: Optional[float] = None,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ stop_sequences: Optional[List[str]] = None,
+ ) -> "TextGenerationResponse":
+ """Sends message to the language model and gets a response.
+
+ Args:
+ message: Message to send to the model
+ max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+ Uses the value specified when calling `ChatModel.start_chat` by default.
+ temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+ Uses the value specified when calling `ChatModel.start_chat` by default.
+ top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
+ Uses the value specified when calling `ChatModel.start_chat` by default.
+ top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+ Uses the value specified when calling `ChatModel.start_chat` by default.
+ stop_sequences: Customized stop sequences to stop the decoding process.
+
+ Returns:
+ A `TextGenerationResponse` object that contains the text produced by the model.
+ """
+ prediction_request = self._prepare_request(
+ message=message,
+ max_output_tokens=max_output_tokens,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ stop_sequences=stop_sequences,
+ )
+
+ prediction_response = self._model._endpoint.predict(
+ instances=[prediction_request.instance],
+ parameters=prediction_request.parameters,
+ )
+ response_obj = self._parse_chat_prediction_response(
+ prediction_response=prediction_response
+ )
response_text = response_obj.text
self._message_history.append(
@@ -884,6 +1493,71 @@ def send_message(
return response_obj
+ def send_message_streaming(
+ self,
+ message: str,
+ *,
+ max_output_tokens: Optional[int] = None,
+ temperature: Optional[float] = None,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ ) -> Iterator[TextGenerationResponse]:
+ """Sends message to the language model and gets a streamed response.
+
+ The response is only added to the history once it's fully read.
+
+ Args:
+ message: Message to send to the model
+ max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+ Uses the value specified when calling `ChatModel.start_chat` by default.
+ temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+ Uses the value specified when calling `ChatModel.start_chat` by default.
+ top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
+ Uses the value specified when calling `ChatModel.start_chat` by default.
+ top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+ Uses the value specified when calling `ChatModel.start_chat` by default.
+
+ Yields:
+ A stream of `TextGenerationResponse` objects that contain partial
+ responses produced by the model.
+ """
+ prediction_request = self._prepare_request(
+ message=message,
+ max_output_tokens=max_output_tokens,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ )
+
+ prediction_service_client = self._model._endpoint._prediction_client
+
+ full_response_text = ""
+
+ for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
+ prediction_service_client=prediction_service_client,
+ endpoint_name=self._model._endpoint_name,
+ instance=prediction_request.instance,
+ parameters=prediction_request.parameters,
+ ):
+ prediction_response = aiplatform.models.Prediction(
+ predictions=[prediction_dict],
+ deployed_model_id="",
+ )
+ text_generation_response = self._parse_chat_prediction_response(
+ prediction_response=prediction_response
+ )
+ full_response_text += text_generation_response.text
+ yield text_generation_response
+
+ # We only add the question and answer to the history if/when the answer
+ # was read fully. Otherwise, the answer would have been truncated.
+ self._message_history.append(
+ ChatMessage(content=message, author=self.USER_AUTHOR)
+ )
+ self._message_history.append(
+ ChatMessage(content=full_response_text, author=self.MODEL_AUTHOR)
+ )
+
class ChatSession(_ChatSessionBase):
"""ChatSession represents a chat session with a language model.
@@ -896,11 +1570,14 @@ def __init__(
model: ChatModel,
context: Optional[str] = None,
examples: Optional[List[InputOutputTextPair]] = None,
- max_output_tokens: Optional[int] = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
+ max_output_tokens: Optional[
+ int
+ ] = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
message_history: Optional[List[ChatMessage]] = None,
+ stop_sequences: Optional[List[str]] = None,
):
super().__init__(
model=model,
@@ -911,6 +1588,7 @@ def __init__(
top_k=top_k,
top_p=top_p,
message_history=message_history,
+ stop_sequences=stop_sequences,
)
@@ -959,6 +1637,34 @@ def send_message(
temperature=temperature,
)
+ def send_message_streaming(
+ self,
+ message: str,
+ *,
+ max_output_tokens: Optional[int] = None,
+ temperature: Optional[float] = None,
+ ) -> Iterator[TextGenerationResponse]:
+ """Sends message to the language model and gets a streamed response.
+
+ The response is only added to the history once it's fully read.
+
+ Args:
+ message: Message to send to the model
+ max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+ Uses the value specified when calling `ChatModel.start_chat` by default.
+ temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+ Uses the value specified when calling `ChatModel.start_chat` by default.
+
+ Returns:
+ A stream of `TextGenerationResponse` objects that contain partial
+ responses produced by the model.
+ """
+ return super().send_message_streaming(
+ message=message,
+ max_output_tokens=max_output_tokens,
+ temperature=temperature,
+ )
+
class CodeGenerationModel(_LanguageModel):
"""A language model that generates code.
@@ -982,21 +1688,24 @@ class CodeGenerationModel(_LanguageModel):
_LAUNCH_STAGE = _model_garden_models._SDK_GA_LAUNCH_STAGE
_DEFAULT_MAX_OUTPUT_TOKENS = 128
- def predict(
+ def _create_prediction_request(
self,
prefix: str,
suffix: Optional[str] = None,
*,
max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS,
temperature: Optional[float] = None,
- ) -> "TextGenerationResponse":
- """Gets model response for a single prompt.
+ stop_sequences: Optional[List[str]] = None,
+ ) -> _PredictionRequest:
+ """Creates a code generation prediction request.
Args:
prefix: Code before the current point.
suffix: Code after the current point.
max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
temperature: Controls the randomness of predictions. Range: [0, 1].
+ stop_sequences: Customized stop sequences to stop the decoding process.
+
Returns:
A `TextGenerationResponse` object that contains the text produced by the model.
@@ -1013,9 +1722,43 @@ def predict(
if max_output_tokens:
prediction_parameters["maxOutputTokens"] = max_output_tokens
+ if stop_sequences:
+ prediction_parameters["stopSequences"] = stop_sequences
+
+ return _PredictionRequest(instance=instance, parameters=prediction_parameters)
+
+ def predict(
+ self,
+ prefix: str,
+ suffix: Optional[str] = None,
+ *,
+ max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS,
+ temperature: Optional[float] = None,
+ stop_sequences: Optional[List[str]] = None,
+ ) -> "TextGenerationResponse":
+ """Gets model response for a single prompt.
+
+ Args:
+ prefix: Code before the current point.
+ suffix: Code after the current point.
+ max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
+ temperature: Controls the randomness of predictions. Range: [0, 1].
+ stop_sequences: Customized stop sequences to stop the decoding process.
+
+ Returns:
+ A `TextGenerationResponse` object that contains the text produced by the model.
+ """
+ prediction_request = self._create_prediction_request(
+ prefix=prefix,
+ suffix=suffix,
+ max_output_tokens=max_output_tokens,
+ temperature=temperature,
+ stop_sequences=stop_sequences,
+ )
+
prediction_response = self._endpoint.predict(
- instances=[instance],
- parameters=prediction_parameters,
+ instances=[prediction_request.instance],
+ parameters=prediction_request.parameters,
)
return TextGenerationResponse(
@@ -1023,6 +1766,51 @@ def predict(
_prediction_response=prediction_response,
)
+ def predict_streaming(
+ self,
+ prefix: str,
+ suffix: Optional[str] = None,
+ *,
+ max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS,
+ temperature: Optional[float] = None,
+ ) -> Iterator[TextGenerationResponse]:
+ """Predicts the code based on previous code.
+
+ The result is a stream (generator) of partial responses.
+
+ Args:
+ prefix: Code before the current point.
+ suffix: Code after the current point.
+ max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
+ temperature: Controls the randomness of predictions. Range: [0, 1].
+
+ Yields:
+ A stream of `TextGenerationResponse` objects that contain partial
+ responses produced by the model.
+ """
+ prediction_request = self._create_prediction_request(
+ prefix=prefix,
+ suffix=suffix,
+ max_output_tokens=max_output_tokens,
+ temperature=temperature,
+ )
+
+ prediction_service_client = self._endpoint._prediction_client
+ for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
+ prediction_service_client=prediction_service_client,
+ endpoint_name=self._endpoint_name,
+ instance=prediction_request.instance,
+ parameters=prediction_request.parameters,
+ ):
+ prediction_obj = aiplatform.models.Prediction(
+ predictions=[prediction_dict],
+ deployed_model_id="",
+ )
+ yield TextGenerationResponse(
+ text=prediction_dict["content"],
+ _prediction_response=prediction_obj,
+ )
+
class _PreviewCodeGenerationModel(CodeGenerationModel, _TunableModelMixin):
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
@@ -1043,11 +1831,12 @@ def __init__(
base_model: _LanguageModel,
job: aiplatform.PipelineJob,
):
+ """Internal constructor. Do not call directly."""
self._base_model = base_model
self._job = job
self._model: Optional[_LanguageModel] = None
- def result(self) -> "_LanguageModel":
+ def get_tuned_model(self) -> "_LanguageModel":
"""Blocks until the tuning is complete and returns a `LanguageModel` object."""
if self._model:
return self._model
@@ -1074,11 +1863,12 @@ def result(self) -> "_LanguageModel":
return self._model
@property
- def status(self):
- """Job status"""
+ def _status(self) -> Optional[aiplatform_types.pipeline_state.PipelineState]:
+ """Job status."""
return self._job.state
- def cancel(self):
+ def _cancel(self):
+ """Cancels the job."""
self._job.cancel()
@@ -1113,50 +1903,39 @@ def _launch_tuning_job(
training_data: Union[str, "pandas.core.frame.DataFrame"],
model_id: str,
tuning_pipeline_uri: str,
- train_steps: Optional[int] = None,
+ tuning_parameters: Dict[str, Any],
model_display_name: Optional[str] = None,
- learning_rate: Optional[float] = None,
tuning_job_location: str = _TUNING_LOCATIONS[0],
) -> aiplatform.PipelineJob:
output_dir_uri = _generate_tuned_model_dir_uri(model_id=model_id)
if isinstance(training_data, str):
- dataset_uri = training_data
+ dataset_name_or_uri = training_data
elif pandas and isinstance(training_data, pandas.DataFrame):
dataset_uri = _uri_join(output_dir_uri, "training_data.jsonl")
gcs_utils._upload_pandas_df_to_gcs(
df=training_data, upload_gcs_path=dataset_uri
)
-
+ dataset_name_or_uri = dataset_uri
else:
raise TypeError(f"Unsupported training_data type: {type(training_data)}")
- job = _launch_tuning_job_on_jsonl_data(
- model_id=model_id,
- dataset_name_or_uri=dataset_uri,
- train_steps=train_steps,
- tuning_pipeline_uri=tuning_pipeline_uri,
- model_display_name=model_display_name,
- learning_rate=learning_rate,
- tuning_job_location=tuning_job_location,
- )
- return job
-
-
-def _launch_tuning_job_on_jsonl_data(
- model_id: str,
- dataset_name_or_uri: str,
- tuning_pipeline_uri: str,
- train_steps: Optional[int] = None,
- learning_rate: Optional[float] = None,
- model_display_name: Optional[str] = None,
- tuning_job_location: str = _TUNING_LOCATIONS[0],
-) -> aiplatform.PipelineJob:
if not model_display_name:
# Creating a human-readable model display name
- name = f"{model_id} tuned for {train_steps} steps"
+ name = f"{model_id} tuned"
+
+ train_steps = tuning_parameters.get("train_steps")
+ if train_steps:
+ name += f" for {train_steps} steps"
+
+ learning_rate = tuning_parameters.get("learning_rate")
if learning_rate:
name += f" with learning rate {learning_rate}"
+
+ learning_rate_multiplier = tuning_parameters.get("learning_rate_multiplier")
+ if learning_rate_multiplier:
+ name += f" with learning_rate_multiplier={learning_rate_multiplier}"
+
name += " on "
# Truncating the start of the dataset URI to keep total length <= 128.
max_display_name_length = 128
@@ -1169,7 +1948,6 @@ def _launch_tuning_job_on_jsonl_data(
model_display_name = name[:max_display_name_length]
pipeline_arguments = {
- "train_steps": train_steps,
"project": aiplatform_initializer.global_config.project,
# TODO(b/275444096): Remove the explicit location once tuning can happen in all regions
# "location": aiplatform_initializer.global_config.location,
@@ -1177,8 +1955,6 @@ def _launch_tuning_job_on_jsonl_data(
"large_model_reference": model_id,
"model_display_name": model_display_name,
}
- if learning_rate:
- pipeline_arguments["learning_rate"] = learning_rate
if dataset_name_or_uri.startswith("projects/"):
pipeline_arguments["dataset_name"] = dataset_name_or_uri
@@ -1188,6 +1964,7 @@ def _launch_tuning_job_on_jsonl_data(
pipeline_arguments[
"encryption_spec_key_name"
] = aiplatform_initializer.global_config.encryption_spec_key_name
+ pipeline_arguments.update(tuning_parameters)
job = aiplatform.PipelineJob(
template_path=tuning_pipeline_uri,
display_name=None,
diff --git a/vertexai/preview/language_models.py b/vertexai/preview/language_models.py
index 6ecf2a6d54..29c38e425d 100644
--- a/vertexai/preview/language_models.py
+++ b/vertexai/preview/language_models.py
@@ -26,9 +26,21 @@
CodeChatSession,
InputOutputTextPair,
TextEmbedding,
+ TextEmbeddingInput,
TextGenerationResponse,
+ TuningEvaluationSpec,
)
+from vertexai.language_models._evaluatable_language_models import (
+ EvaluationTextGenerationSpec,
+ EvaluationTextSummarizationSpec,
+ EvaluationQuestionAnsweringSpec,
+ EvaluationTextClassificationSpec,
+ EvaluationClassificationMetric,
+ EvaluationMetric,
+)
+
+
ChatModel = _PreviewChatModel
CodeChatModel = _PreviewCodeChatModel
CodeGenerationModel = _PreviewCodeGenerationModel
@@ -42,9 +54,17 @@
"CodeChatModel",
"CodeChatSession",
"CodeGenerationModel",
+ "EvaluationClassificationMetric",
+ "EvaluationMetric",
+ "EvaluationTextGenerationSpec",
+ "EvaluationTextSummarizationSpec",
+ "EvaluationQuestionAnsweringSpec",
+ "EvaluationTextClassificationSpec",
"InputOutputTextPair",
"TextEmbedding",
+ "TextEmbeddingInput",
"TextEmbeddingModel",
"TextGenerationModel",
"TextGenerationResponse",
+ "TuningEvaluationSpec",
]
diff --git a/vertexai/preview/vision_models.py b/vertexai/preview/vision_models.py
index fb3a32fd5f..67290e6736 100644
--- a/vertexai/preview/vision_models.py
+++ b/vertexai/preview/vision_models.py
@@ -16,16 +16,22 @@
from vertexai.vision_models._vision_models import (
Image,
+ ImageGenerationModel,
+ ImageGenerationResponse,
ImageCaptioningModel,
ImageQnAModel,
+ GeneratedImage,
MultiModalEmbeddingModel,
MultiModalEmbeddingResponse,
)
__all__ = [
"Image",
+ "ImageGenerationModel",
+ "ImageGenerationResponse",
"ImageCaptioningModel",
"ImageQnAModel",
+ "GeneratedImage",
"MultiModalEmbeddingModel",
"MultiModalEmbeddingResponse",
]
diff --git a/vertexai/vision_models/__init__.py b/vertexai/vision_models/__init__.py
new file mode 100644
index 0000000000..fb3a32fd5f
--- /dev/null
+++ b/vertexai/vision_models/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Classes for working with vision models."""
+
+from vertexai.vision_models._vision_models import (
+ Image,
+ ImageCaptioningModel,
+ ImageQnAModel,
+ MultiModalEmbeddingModel,
+ MultiModalEmbeddingResponse,
+)
+
+__all__ = [
+ "Image",
+ "ImageCaptioningModel",
+ "ImageQnAModel",
+ "MultiModalEmbeddingModel",
+ "MultiModalEmbeddingResponse",
+]
diff --git a/vertexai/vision_models/_vision_models.py b/vertexai/vision_models/_vision_models.py
index 6b6ca8f695..755f4b8ff9 100644
--- a/vertexai/vision_models/_vision_models.py
+++ b/vertexai/vision_models/_vision_models.py
@@ -16,9 +16,12 @@
import base64
import dataclasses
+import hashlib
import io
+import json
import pathlib
-from typing import Any, List, Optional
+import typing
+from typing import Any, Dict, List, Optional, Union
from vertexai._model_garden import _model_garden_models
@@ -34,6 +37,9 @@
PIL_Image = None
+_SUPPORTED_UPSCALING_SIZES = [2048, 4096]
+
+
class Image:
"""Image."""
@@ -100,6 +106,378 @@ def _as_base64_string(self) -> str:
return base64.b64encode(self._image_bytes).decode("ascii")
+class ImageGenerationModel(
+ _model_garden_models._ModelGardenModel # pylint: disable=protected-access
+):
+ """Generates images from text prompt.
+
+ Examples::
+
+ model = ImageGenerationModel.from_pretrained("imagegeneration@002")
+ response = model.generate_images(
+ prompt="Astronaut riding a horse",
+ # Optional:
+ number_of_images=1,
+ width=1024,
+ width=768,
+ seed=0,
+ )
+ response[0].show()
+ response[0].save("image1.png")
+ """
+
+ _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/vision_generative_model_1.0.0.yaml"
+
+ def _generate_images(
+ self,
+ prompt: str,
+ *,
+ negative_prompt: Optional[str] = None,
+ number_of_images: int = 1,
+ width: Optional[int] = None,
+ height: Optional[int] = None,
+ guidance_scale: Optional[float] = None,
+ seed: Optional[int] = None,
+ base_image: Optional["Image"] = None,
+ mask: Optional["Image"] = None,
+ ) -> "ImageGenerationResponse":
+ """Generates images from text prompt.
+
+ Args:
+ prompt: Text prompt for the image.
+ negative_prompt: A description of what you want to omit in
+ the generated images.
+ number_of_images: Number of images to generate. Range: 1..8.
+ width: Width of the image. One of the sizes must be 256 or 1024.
+ height: Height of the image. One of the sizes must be 256 or 1024.
+ guidance_scale: Controls the strength of the prompt.
+ Suggested values are:
+ * 0-9 (low strength)
+ * 10-20 (medium strength)
+ * 21+ (high strength)
+ seed: Image generation random seed.
+ base_image: Base image to use for the image generation.
+ mask: Mask for the base image.
+
+ Returns:
+ An `ImageGenerationResponse` object.
+ """
+ # Note: Only a single prompt is supported by the service.
+ instance = {"prompt": prompt}
+ shared_generation_parameters = {
+ "prompt": prompt,
+ # b/295946075 The service stopped supporting image sizes.
+ # "width": width,
+ # "height": height,
+ "number_of_images_in_batch": number_of_images,
+ }
+
+ if negative_prompt:
+ instance["negativePrompt"] = negative_prompt
+ shared_generation_parameters["negative_prompt"] = negative_prompt
+
+ if base_image:
+ base_image_base64 = (
+ base_image._as_base64_string()
+ ) # pylint: disable=protected-access
+ instance["image"] = {"bytesBase64Encoded": base_image_base64}
+ base_image_hash_hex = hashlib.sha1(
+ base_image._image_bytes # pylint: disable=protected-access
+ ).hexdigest()
+ shared_generation_parameters["base_image_hash"] = base_image_hash_hex
+
+ if mask:
+ mask_image_base64 = (
+ mask._as_base64_string()
+ ) # pylint: disable=protected-access
+ instance["mask"] = {"image": {"bytesBase64Encoded": mask_image_base64}}
+ mask_image_hash_hex = hashlib.sha1(
+ mask._image_bytes # pylint: disable=protected-access
+ ).hexdigest()
+ shared_generation_parameters["mask_hash"] = mask_image_hash_hex
+
+ parameters = {}
+ max_size = max(width or 0, height or 0) or None
+ if max_size:
+ # Note: The size needs to be a string
+ parameters["sampleImageSize"] = str(max_size)
+ if height is not None and width is not None and height != width:
+ parameters["aspectRatio"] = f"{width}:{height}"
+
+ parameters["sampleCount"] = number_of_images
+
+ if seed is not None:
+ # Note: String seed and numerical seed give different results
+ parameters["seed"] = seed
+ shared_generation_parameters["seed"] = seed
+
+ if guidance_scale is not None:
+ parameters["guidanceScale"] = guidance_scale
+ shared_generation_parameters["guidance_scale"] = guidance_scale
+
+ response = self._endpoint.predict(
+ instances=[instance],
+ parameters=parameters,
+ )
+
+ generated_images: List["GeneratedImage"] = []
+ for idx, prediction in enumerate(response.predictions):
+ image_bytes = base64.b64decode(prediction["bytesBase64Encoded"])
+ generation_parameters = dict(shared_generation_parameters)
+ generation_parameters["index_of_image_in_batch"] = idx
+ generated_image = GeneratedImage(
+ image_bytes=image_bytes,
+ generation_parameters=generation_parameters,
+ )
+ generated_images.append(generated_image)
+
+ return ImageGenerationResponse(images=generated_images)
+
+ def generate_images(
+ self,
+ prompt: str,
+ *,
+ negative_prompt: Optional[str] = None,
+ number_of_images: int = 1,
+ guidance_scale: Optional[float] = None,
+ seed: Optional[int] = None,
+ ) -> "ImageGenerationResponse":
+ """Generates images from text prompt.
+
+ Args:
+ prompt: Text prompt for the image.
+ negative_prompt: A description of what you want to omit in
+ the generated images.
+ number_of_images: Number of images to generate. Range: 1..8.
+ guidance_scale: Controls the strength of the prompt.
+ Suggested values are:
+ * 0-9 (low strength)
+ * 10-20 (medium strength)
+ * 21+ (high strength)
+ seed: Image generation random seed.
+
+ Returns:
+ An `ImageGenerationResponse` object.
+ """
+ return self._generate_images(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ number_of_images=number_of_images,
+ # b/295946075 The service stopped supporting image sizes.
+ width=None,
+ height=None,
+ guidance_scale=guidance_scale,
+ seed=seed,
+ )
+
+ def edit_image(
+ self,
+ *,
+ prompt: str,
+ base_image: "Image",
+ mask: Optional["Image"] = None,
+ negative_prompt: Optional[str] = None,
+ number_of_images: int = 1,
+ guidance_scale: Optional[float] = None,
+ seed: Optional[int] = None,
+ ) -> "ImageGenerationResponse":
+ """Edits an existing image based on text prompt.
+
+ Args:
+ prompt: Text prompt for the image.
+ base_image: Base image from which to generate the new image.
+ mask: Mask for the base image.
+ negative_prompt: A description of what you want to omit in
+ the generated images.
+ number_of_images: Number of images to generate. Range: 1..8.
+ guidance_scale: Controls the strength of the prompt.
+ Suggested values are:
+ * 0-9 (low strength)
+ * 10-20 (medium strength)
+ * 21+ (high strength)
+ seed: Image generation random seed.
+
+ Returns:
+ An `ImageGenerationResponse` object.
+ """
+ return self._generate_images(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ number_of_images=number_of_images,
+ guidance_scale=guidance_scale,
+ seed=seed,
+ base_image=base_image,
+ mask=mask,
+ )
+
+ def upscale_image(
+ self,
+ image: Union["Image", "GeneratedImage"],
+ new_size: Optional[int] = 2048,
+ ) -> "Image":
+ """Upscales an image.
+
+ This supports upscaling images generated through the `generate_images()` method,
+ or upscaling a new image that is 1024x1024.
+
+ Examples::
+
+ # Upscale a generated image
+ model = ImageGenerationModel.from_pretrained("imagegeneration@002")
+ response = model.generate_images(
+ prompt="Astronaut riding a horse",
+ )
+ model.upscale_image(image=response[0])
+
+ # Upscale a new 1024x1024 image
+ my_image = Image.load_from_file("my-image.png")
+ model.upscale_image(image=my_image)
+
+ Args:
+ image (Union[GeneratedImage, Image]):
+ Required. The generated image to upscale.
+ new_size (int):
+ The size of the biggest dimension of the upscaled image. Only 2048 and 4096 are currently
+ supported. Results in a 2048x2048 or 4096x4096 image. Defaults to 2048 if not provided.
+
+ Returns:
+ An `Image` object.
+ """
+
+ # Currently this method only supports 1024x1024 images
+ if image._size[0] != 1024 and image._size[1] != 1024:
+ raise ValueError(
+ "Upscaling is currently only supported on images that are 1024x1024."
+ )
+
+ if new_size not in _SUPPORTED_UPSCALING_SIZES:
+ raise ValueError(
+ f"Only the folowing square upscaling sizes are currently supported: {_SUPPORTED_UPSCALING_SIZES}."
+ )
+
+ instance = {
+ "prompt": "",
+ "image": {"bytesBase64Encoded": image._as_base64_string()},
+ }
+
+ parameters = {
+ "sampleImageSize": str(new_size),
+ "sampleCount": 1,
+ "mode": "upscale",
+ }
+
+ response = self._endpoint.predict(
+ instances=[instance],
+ parameters=parameters,
+ )
+
+ upscaled_image = response.predictions[0]
+
+ if isinstance(image, GeneratedImage):
+ generation_parameters = image.generation_parameters
+
+ else:
+ generation_parameters = {}
+
+ generation_parameters["upscaled_image_size"] = new_size
+
+ return GeneratedImage(
+ image_bytes=base64.b64decode(upscaled_image["bytesBase64Encoded"]),
+ generation_parameters=generation_parameters,
+ )
+
+
+@dataclasses.dataclass
+class ImageGenerationResponse:
+ """Image generation response.
+
+ Attributes:
+ images: The list of generated images.
+ """
+
+ images: List["GeneratedImage"]
+
+ def __iter__(self) -> typing.Iterator["GeneratedImage"]:
+ """Iterates through the generated images."""
+ yield from self.images
+
+ def __getitem__(self, idx: int) -> "GeneratedImage":
+ """Gets the generated image by index."""
+ return self.images[idx]
+
+
+_EXIF_USER_COMMENT_TAG_IDX = 0x9286
+_IMAGE_GENERATION_PARAMETERS_EXIF_KEY = (
+ "google.cloud.vertexai.image_generation.image_generation_parameters"
+)
+
+
+class GeneratedImage(Image):
+ """Generated image."""
+
+ def __init__(
+ self,
+ image_bytes: bytes,
+ generation_parameters: Dict[str, Any],
+ ):
+ """Creates a `GeneratedImage` object.
+
+ Args:
+ image_bytes: Image file bytes. Image can be in PNG or JPEG format.
+ generation_parameters: Image generation parameter values.
+ """
+ super().__init__(image_bytes=image_bytes)
+ self._generation_parameters = generation_parameters
+
+ @property
+ def generation_parameters(self):
+ """Image generation parameters as a dictionary."""
+ return self._generation_parameters
+
+ @staticmethod
+ def load_from_file(location: str) -> "GeneratedImage":
+ """Loads image from file.
+
+ Args:
+ location: Local path from where to load the image.
+
+ Returns:
+ Loaded image as a `GeneratedImage` object.
+ """
+ base_image = Image.load_from_file(location=location)
+ exif = base_image._pil_image.getexif() # pylint: disable=protected-access
+ exif_comment_dict = json.loads(exif[_EXIF_USER_COMMENT_TAG_IDX])
+ generation_parameters = exif_comment_dict[_IMAGE_GENERATION_PARAMETERS_EXIF_KEY]
+ return GeneratedImage(
+ image_bytes=base_image._image_bytes, # pylint: disable=protected-access
+ generation_parameters=generation_parameters,
+ )
+
+ def save(self, location: str, include_generation_parameters: bool = True):
+ """Saves image to a file.
+
+ Args:
+ location: Local path where to save the image.
+ include_generation_parameters: Whether to include the image
+ generation parameters in the image's EXIF metadata.
+ """
+ if include_generation_parameters:
+ if not self._generation_parameters:
+ raise ValueError("Image does not have generation parameters.")
+ if not PIL_Image:
+ raise ValueError(
+ "The PIL module is required for saving generation parameters."
+ )
+
+ exif = self._pil_image.getexif()
+ exif[_EXIF_USER_COMMENT_TAG_IDX] = json.dumps(
+ {_IMAGE_GENERATION_PARAMETERS_EXIF_KEY: self._generation_parameters}
+ )
+ self._pil_image.save(location, exif=exif)
+ else:
+ super().save(location=location)
+
+
class ImageCaptioningModel(
_model_garden_models._ModelGardenModel # pylint: disable=protected-access
):