diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 307453a670..e180a38341 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "1.24.1" + ".": "1.25.0" } diff --git a/CHANGELOG.md b/CHANGELOG.md index ebd1f3be40..51f86cb55a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,24 @@ # Changelog +## [1.25.0](https://github.com/googleapis/python-aiplatform/compare/v1.24.1...v1.25.0) (2023-05-09) + + +### Features + +* Add support for Large Language Models ([866c6aa](https://github.com/googleapis/python-aiplatform/commit/866c6aaf72b9a7a5f6155665f574cc11cf8075f4)) +* Add default TensorBoard support. ([fa7d3a0](https://github.com/googleapis/python-aiplatform/commit/fa7d3a0e3cd5040eb4ab1c3b0df4e494dc84bac3)) +* Add support for find_neighbors/read_index_datapoints in matching engine public endpoint ([e3a87f0](https://github.com/googleapis/python-aiplatform/commit/e3a87f04abf013341fe4f655b96405e27228ffdb)) +* Added the new root `vertexai` package ([fbd03b1](https://github.com/googleapis/python-aiplatform/commit/fbd03b15e9b71cbeeaebc868745a36c892b55c8f)) + + +### Bug Fixes + +* EntityType RPC update returns the updated EntityType - not an LRO. ([8f9c714](https://github.com/googleapis/python-aiplatform/commit/8f9c7144c152e105924d87abb30aa734af376486)) +* Fix default AutoML Forecasting transformations list. ([77b89c0](https://github.com/googleapis/python-aiplatform/commit/77b89c0151ce3647b8fac8f4e8b6a7f7c07a1192)) +* Fix type hints for `Prediction.predictions`. ([56518f1](https://github.com/googleapis/python-aiplatform/commit/56518f166215761354aba43d78301a11d198daf5)) +* Removed parameter Resume, due to causing confusion and errors. 
([c82e0b5](https://github.com/googleapis/python-aiplatform/commit/c82e0b5fb74fe9ba15f9d0f14a441349499ee257)) + ## [1.24.1](https://github.com/googleapis/python-aiplatform/compare/v1.24.0...v1.24.1) (2023-04-21) diff --git a/docs/aiplatform/services.rst b/docs/aiplatform/services.rst index a1a8e2af42..89a65d4b46 100644 --- a/docs/aiplatform/services.rst +++ b/docs/aiplatform/services.rst @@ -5,3 +5,6 @@ Google Cloud Aiplatform SDK :members: :show-inheritance: :inherited-members: + +.. autoclass:: google.cloud.aiplatform.metadata.schema.google.artifact_schema.ExperimentModel + :members: \ No newline at end of file diff --git a/docs/aiplatform_v1beta1/model_garden_service.rst b/docs/aiplatform_v1beta1/model_garden_service.rst new file mode 100644 index 0000000000..3a4a28e6f7 --- /dev/null +++ b/docs/aiplatform_v1beta1/model_garden_service.rst @@ -0,0 +1,6 @@ +ModelGardenService +------------------------------------ + +.. automodule:: google.cloud.aiplatform_v1beta1.services.model_garden_service + :members: + :inherited-members: diff --git a/docs/aiplatform_v1beta1/services.rst b/docs/aiplatform_v1beta1/services.rst index 39ef83dea8..1d546a84ee 100644 --- a/docs/aiplatform_v1beta1/services.rst +++ b/docs/aiplatform_v1beta1/services.rst @@ -14,6 +14,7 @@ Services for Google Cloud Aiplatform v1beta1 API match_service metadata_service migration_service + model_garden_service model_service pipeline_service prediction_service diff --git a/docs/index.rst b/docs/index.rst index 73d0e542fd..99a912a20a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,6 +7,8 @@ API Reference .. toctree:: :maxdepth: 2 + vertexai/services + aiplatform/services aiplatform/types aiplatform/prediction diff --git a/docs/vertexai/services.rst b/docs/vertexai/services.rst new file mode 100644 index 0000000000..bdf5234132 --- /dev/null +++ b/docs/vertexai/services.rst @@ -0,0 +1,12 @@ +Vertex AI SDK +============================================= + +.. 
automodule:: vertexai + :members: + :show-inheritance: + :inherited-members: + +.. automodule:: vertexai.preview.language_models + :members: + :show-inheritance: + :inherited-members: diff --git a/google/cloud/aiplatform/compat/__init__.py b/google/cloud/aiplatform/compat/__init__.py index e97c578999..7538387da9 100644 --- a/google/cloud/aiplatform/compat/__init__.py +++ b/google/cloud/aiplatform/compat/__init__.py @@ -36,11 +36,13 @@ services.featurestore_service_client = services.featurestore_service_client_v1beta1 services.job_service_client = services.job_service_client_v1beta1 services.model_service_client = services.model_service_client_v1beta1 + services.model_garden_service_client = services.model_garden_service_client_v1beta1 services.pipeline_service_client = services.pipeline_service_client_v1beta1 services.prediction_service_client = services.prediction_service_client_v1beta1 services.specialist_pool_service_client = ( services.specialist_pool_service_client_v1beta1 ) + services.match_service_client = services.match_service_client_v1beta1 services.metadata_service_client = services.metadata_service_client_v1beta1 services.tensorboard_service_client = services.tensorboard_service_client_v1beta1 services.index_service_client = services.index_service_client_v1beta1 @@ -102,6 +104,7 @@ types.model_deployment_monitoring_job = ( types.model_deployment_monitoring_job_v1beta1 ) + types.model_garden_service = types.model_garden_service_v1beta1 types.model_monitoring = types.model_monitoring_v1beta1 types.model_service = types.model_service_v1beta1 types.operation = types.operation_v1beta1 @@ -110,6 +113,7 @@ types.pipeline_service = types.pipeline_service_v1beta1 types.pipeline_state = types.pipeline_state_v1beta1 types.prediction_service = types.prediction_service_v1beta1 + types.publisher_model = types.publisher_model_v1beta1 types.specialist_pool = types.specialist_pool_v1beta1 types.specialist_pool_service = types.specialist_pool_service_v1beta1 types.study 
= types.study_v1beta1 diff --git a/google/cloud/aiplatform/compat/services/__init__.py b/google/cloud/aiplatform/compat/services/__init__.py index 25ff66b1d1..4dc073c3b2 100644 --- a/google/cloud/aiplatform/compat/services/__init__.py +++ b/google/cloud/aiplatform/compat/services/__init__.py @@ -39,9 +39,15 @@ from google.cloud.aiplatform_v1beta1.services.job_service import ( client as job_service_client_v1beta1, ) +from google.cloud.aiplatform_v1beta1.services.match_service import ( + client as match_service_client_v1beta1, +) from google.cloud.aiplatform_v1beta1.services.metadata_service import ( client as metadata_service_client_v1beta1, ) +from google.cloud.aiplatform_v1beta1.services.model_garden_service import ( + client as model_garden_service_client_v1beta1, +) from google.cloud.aiplatform_v1beta1.services.model_service import ( client as model_service_client_v1beta1, ) @@ -129,6 +135,8 @@ index_service_client_v1beta1, index_endpoint_service_client_v1beta1, job_service_client_v1beta1, + match_service_client_v1beta1, + model_garden_service_client_v1beta1, model_service_client_v1beta1, pipeline_service_client_v1beta1, prediction_service_client_v1beta1, diff --git a/google/cloud/aiplatform/compat/types/__init__.py b/google/cloud/aiplatform/compat/types/__init__.py index 0356bac918..dcc0fa4e62 100644 --- a/google/cloud/aiplatform/compat/types/__init__.py +++ b/google/cloud/aiplatform/compat/types/__init__.py @@ -57,6 +57,7 @@ lineage_subgraph as lineage_subgraph_v1beta1, machine_resources as machine_resources_v1beta1, manual_batch_tuning_parameters as manual_batch_tuning_parameters_v1beta1, + match_service as match_service_v1beta1, metadata_schema as metadata_schema_v1beta1, metadata_service as metadata_service_v1beta1, metadata_store as metadata_store_v1beta1, @@ -64,6 +65,7 @@ model_evaluation as model_evaluation_v1beta1, model_evaluation_slice as model_evaluation_slice_v1beta1, model_deployment_monitoring_job as model_deployment_monitoring_job_v1beta1, + 
model_garden_service as model_garden_service_v1beta1, model_service as model_service_v1beta1, model_monitoring as model_monitoring_v1beta1, operation as operation_v1beta1, @@ -72,6 +74,7 @@ pipeline_service as pipeline_service_v1beta1, pipeline_state as pipeline_state_v1beta1, prediction_service as prediction_service_v1beta1, + publisher_model as publisher_model_v1beta1, specialist_pool as specialist_pool_v1beta1, specialist_pool_service as specialist_pool_service_v1beta1, study as study_v1beta1, @@ -203,7 +206,7 @@ model_service_v1, model_monitoring_v1, operation_v1, - pipeline_failure_policy_v1beta1, + pipeline_failure_policy_v1, pipeline_job_v1, pipeline_service_v1, pipeline_state_v1, @@ -218,6 +221,8 @@ tensorboard_time_series_v1, training_pipeline_v1, types_v1, + study_v1, + vizier_service_v1, # v1beta1 accelerator_type_v1beta1, annotation_v1beta1, @@ -260,6 +265,7 @@ matching_engine_deployed_index_ref_v1beta1, index_v1beta1, index_endpoint_v1beta1, + match_service_v1beta1, metadata_service_v1beta1, metadata_schema_v1beta1, metadata_store_v1beta1, @@ -267,6 +273,7 @@ model_evaluation_v1beta1, model_evaluation_slice_v1beta1, model_deployment_monitoring_job_v1beta1, + model_garden_service_v1beta1, model_service_v1beta1, model_monitoring_v1beta1, operation_v1beta1, @@ -275,8 +282,10 @@ pipeline_service_v1beta1, pipeline_state_v1beta1, prediction_service_v1beta1, + publisher_model_v1beta1, specialist_pool_v1beta1, specialist_pool_service_v1beta1, + study_v1beta1, tensorboard_v1beta1, tensorboard_data_v1beta1, tensorboard_experiment_v1beta1, @@ -285,4 +294,5 @@ tensorboard_time_series_v1beta1, training_pipeline_v1beta1, types_v1beta1, + vizier_service_v1beta1, ) diff --git a/google/cloud/aiplatform/featurestore/_entity_type.py b/google/cloud/aiplatform/featurestore/_entity_type.py index 859e5fc7ce..fee48a05cb 100644 --- a/google/cloud/aiplatform/featurestore/_entity_type.py +++ b/google/cloud/aiplatform/featurestore/_entity_type.py @@ -249,18 +249,15 @@ def update( 
self, ) - update_entity_type_lro = self.api_client.update_entity_type( + updated_entity_type = self.api_client.update_entity_type( entity_type=gapic_entity_type, update_mask=update_mask, metadata=request_metadata, timeout=update_request_timeout, ) - _LOGGER.log_action_started_against_resource_with_lro( - "Update", "entityType", self.__class__, update_entity_type_lro - ) - - update_entity_type_lro.result() + # Update underlying resource with response data. + self._gca_resource = updated_entity_type _LOGGER.log_action_completed_against_resource("entityType", "updated", self) diff --git a/google/cloud/aiplatform/gapic_version.py b/google/cloud/aiplatform/gapic_version.py index e1251ccfff..085d0e2f28 100644 --- a/google/cloud/aiplatform/gapic_version.py +++ b/google/cloud/aiplatform/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.24.1" # {x-release-please-version} +__version__ = "1.25.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index 93cf8d730d..dbd3e9bda1 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -279,6 +279,7 @@ def get_client_options( location_override: Optional[str] = None, prediction_client: bool = False, api_base_path_override: Optional[str] = None, + api_path_override: Optional[str] = None, ) -> client_options.ClientOptions: """Creates GAPIC client_options using location and type. @@ -289,6 +290,7 @@ def get_client_options( Vertex AI. prediction_client (str): Optional. flag to use a prediction endpoint. api_base_path_override (str): Optional. Override default API base path. + api_path_override (str): Optional. Override default api path. Returns: clients_options (google.api_core.client_options.ClientOptions): A ClientOptions object set with regionalized API endpoint, i.e. 
@@ -311,9 +313,12 @@ def get_client_options( else constants.API_BASE_PATH ) - return client_options.ClientOptions( - api_endpoint=f"{region}-{service_base_path}" + api_endpoint = ( + f"{region}-{service_base_path}" + if not api_path_override + else api_path_override ) + return client_options.ClientOptions(api_endpoint=api_endpoint) def common_location_path( self, project: Optional[str] = None, location: Optional[str] = None @@ -345,6 +350,7 @@ def create_client( location_override: Optional[str] = None, prediction_client: bool = False, api_base_path_override: Optional[str] = None, + api_path_override: Optional[str] = None, appended_user_agent: Optional[List[str]] = None, ) -> utils.VertexAiServiceClientWithOverride: """Instantiates a given VertexAiServiceClient with optional @@ -358,6 +364,7 @@ def create_client( location_override (str): Optional. location override. prediction_client (str): Optional. flag to use a prediction endpoint. api_base_path_override (str): Optional. Override default api base path. + api_path_override (str): Optional. Override default api path. appended_user_agent (List[str]): Optional. User agent appended in the client info. If more than one, it will be separated by spaces. 
@@ -383,6 +390,7 @@ def create_client( location_override=location_override, prediction_client=prediction_client, api_base_path_override=api_base_path_override, + api_path_override=api_path_override, ), "client_info": client_info, } diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py index 9d5aba2db2..4b04b01688 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py @@ -26,6 +26,8 @@ from google.cloud.aiplatform.compat.types import ( machine_resources as gca_machine_resources_compat, matching_engine_index_endpoint as gca_matching_engine_index_endpoint, + match_service_v1beta1 as gca_match_service_v1beta1, + index_v1beta1 as gca_index_v1beta1, ) from google.cloud.aiplatform.matching_engine._protos import match_service_pb2 from google.cloud.aiplatform.matching_engine._protos import ( @@ -127,6 +129,9 @@ def __init__( ) self._gca_resource = self._get_gca_resource(resource_name=index_endpoint_name) + if self.public_endpoint_domain_name: + self._public_match_client = self._instantiate_public_match_client() + @classmethod def create( cls, @@ -344,6 +349,22 @@ def _create( return index_obj + def _instantiate_public_match_client( + self, + ) -> utils.MatchClientWithOverride: + """Helper method to instantiates match client with optional + overrides for this endpoint. + Returns: + match_client (match_service_client.MatchServiceClient): + Initialized match client with optional overrides. 
+ """ + return initializer.global_config.create_client( + client_class=utils.MatchClientWithOverride, + credentials=self.credentials, + location_override=self.location, + api_path_override=self.public_endpoint_domain_name, + ) + @property def public_endpoint_domain_name(self) -> Optional[str]: """Public endpoint DNS name.""" @@ -928,6 +949,124 @@ def description(self) -> str: self._assert_gca_resource_is_available() return self._gca_resource.description + def find_neighbors( + self, + *, + deployed_index_id: str, + queries: List[List[float]], + num_neighbors: int = 10, + filter: Optional[List[Namespace]] = [], + ) -> List[List[MatchNeighbor]]: + """Retrieves nearest neighbors for the given embedding queries on the specified deployed index which is deployed to public endpoint. + + ``` + Example usage: + my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + index_endpoint_name='projects/123/locations/us-central1/index_endpoint/my_index_id' + ) + my_index_endpoint.find_neighbors(deployed_index_id="public_test1", queries= [[1, 1]],) + ``` + Args: + deployed_index_id (str): + Required. The ID of the DeployedIndex to match the queries against. + queries (List[List[float]]): + Required. A list of queries. Each query is a list of floats, representing a single embedding. + num_neighbors (int): + Required. The number of nearest neighbors to be retrieved from database for + each query. + filter (List[Namespace]): + Optional. A list of Namespaces for filtering the matching results. + For example, [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])] will match datapoints + that satisfy "red color" but not include datapoints with "squared shape". + Please refer to https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json for more detail. + Returns: + List[List[MatchNeighbor]] - A list of nearest neighbors for each query. 
+ """ + + if not self._public_match_client: + raise ValueError( + "Please make sure index has been deployed to public endpoint, and follow the example usage to call this method." + ) + + # Create the FindNeighbors request + find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest() + find_neighbors_request.index_endpoint = self.resource_name + find_neighbors_request.deployed_index_id = deployed_index_id + + for query in queries: + find_neighbors_query = ( + gca_match_service_v1beta1.FindNeighborsRequest.Query() + ) + find_neighbors_query.neighbor_count = num_neighbors + datapoint = gca_index_v1beta1.IndexDatapoint(feature_vector=query) + for namespace in filter: + restrict = gca_index_v1beta1.IndexDatapoint.Restriction() + restrict.namespace = namespace.name + restrict.allow_list.extend(namespace.allow_tokens) + restrict.deny_list.extend(namespace.deny_tokens) + datapoint.restricts.append(restrict) + find_neighbors_query.datapoint = datapoint + find_neighbors_request.queries.append(find_neighbors_query) + + response = self._public_match_client.find_neighbors(find_neighbors_request) + + # Wrap the results in MatchNeighbor objects and return + return [ + [ + MatchNeighbor( + id=neighbor.datapoint.datapoint_id, distance=neighbor.distance + ) + for neighbor in embedding_neighbors.neighbors + ] + for embedding_neighbors in response.nearest_neighbors + ] + + def read_index_datapoints( + self, + *, + deployed_index_id: str, + ids: List[str] = [], + ) -> List[gca_index_v1beta1.IndexDatapoint]: + """Reads the datapoints/vectors of the given IDs on the specified deployed index which is deployed to public endpoint. + + ``` + Example Usage: + my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + index_endpoint_name='projects/123/locations/us-central1/index_endpoint/my_index_id' + ) + my_index_endpoint.read_index_datapoints(deployed_index_id="public_test1", ids= ["606431", "896688"],) + ``` + + Args: + deployed_index_id (str): + Required. 
The ID of the DeployedIndex to match the queries against. + ids (List[str]): + Required. IDs of the datapoints to be searched for. + Returns: + List[gca_index_v1beta1.IndexDatapoint] - A list of datapoints/vectors of the given IDs. + """ + if not self._public_match_client: + raise ValueError( + "Please make sure index has been deployed to public endpoint, and follow the example usage to call this method." + ) + + # Create the ReadIndexDatapoints request + read_index_datapoints_request = ( + gca_match_service_v1beta1.ReadIndexDatapointsRequest() + ) + read_index_datapoints_request.index_endpoint = self.resource_name + read_index_datapoints_request.deployed_index_id = deployed_index_id + + for id in ids: + read_index_datapoints_request.ids.append(id) + + response = self._public_match_client.read_index_datapoints( + read_index_datapoints_request + ) + + # Wrap the results and return + return response.datapoints + def match( self, deployed_index_id: str, diff --git a/google/cloud/aiplatform/metadata/metadata.py b/google/cloud/aiplatform/metadata/metadata.py index 5f8a45f427..e63d5e671b 100644 --- a/google/cloud/aiplatform/metadata/metadata.py +++ b/google/cloud/aiplatform/metadata/metadata.py @@ -15,6 +15,7 @@ # limitations under the License. # +import datetime import logging import os from typing import Dict, Union, Optional, Any, List @@ -61,6 +62,24 @@ def _get_experiment_schema_version() -> str: return constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT] +def _get_or_create_default_tensorboard() -> tensorboard_resource.Tensorboard: + """Helper method to get the default TensorBoard instance if already exists, or create a default TensorBoard instance. + + Returns: + tensorboard_resource.Tensorboard: the default TensorBoard instance. 
+ """ + tensorboards = tensorboard_resource.Tensorboard.list(filter="is_default=true") + if tensorboards: + return tensorboards[0] + else: + default_tensorboard = tensorboard_resource.Tensorboard.create( + display_name="Default Tensorboard " + + datetime.datetime.now().isoformat(sep=" "), + is_default=True, + ) + return default_tensorboard + + # Legacy Experiment tracking # Maintaining creation APIs for backwards compatibility testing class _LegacyExperimentService: @@ -268,7 +287,11 @@ def set_experiment( experiment_name=experiment, description=description ) - backing_tb = backing_tensorboard or self._global_tensorboard + backing_tb = ( + backing_tensorboard + or self._global_tensorboard + or _get_or_create_default_tensorboard() + ) current_backing_tb = experiment.backing_tensorboard_resource_name @@ -277,16 +300,6 @@ def set_experiment( self._experiment = experiment - if ( - not current_backing_tb - and not backing_tb - and autologging_utils._is_autologging_enabled() - ): - logging.warning( - "Disabling autologging since the current Experiment doesn't have a backing Tensorboard." - ) - self.autolog(disable=True) - def set_tensorboard( self, tensorboard: Union[ diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index 893770bffa..9b46420c32 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -147,7 +147,7 @@ class Prediction(NamedTuple): of elements as instances to be explained. Default is None. 
""" - predictions: List[Dict[str, Any]] + predictions: List[Any] deployed_model_id: str model_version_id: Optional[str] = None model_resource_name: Optional[str] = None diff --git a/google/cloud/aiplatform/preview/_publisher_model.py b/google/cloud/aiplatform/preview/_publisher_model.py new file mode 100644 index 0000000000..1a27c3f469 --- /dev/null +++ b/google/cloud/aiplatform/preview/_publisher_model.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import re +from typing import Optional + +from google.auth import credentials as auth_credentials +from google.cloud.aiplatform import base +from google.cloud.aiplatform import utils + + +class _PublisherModel(base.VertexAiResourceNoun): + """Publisher Model Resource for Vertex AI.""" + + client_class = utils.ModelGardenClientWithOverride + + _resource_noun = "publisher_model" + _getter_method = "get_publisher_model" + _delete_method = None + _parse_resource_name_method = "parse_publisher_model_path" + _format_resource_name_method = "publisher_model_path" + + def __init__( + self, + resource_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Retrieves an existing PublisherModel resource given a resource name or model garden id. + + Args: + resource_name (str): + Required. A fully-qualified PublisherModel resource name or + model garden id. 
Format: + `publishers/{publisher}/models/{publisher_model}` or + `{publisher}/{publisher_model}`. + project (str): + Optional. Project to retrieve the resource from. If not set, + project set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve the resource from. If not set, + location set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve the resource. + Overrides credentials set in aiplatform.init. + """ + + super().__init__(project=project, location=location, credentials=credentials) + + if self._parse_resource_name(resource_name): + full_resource_name = resource_name + else: + m = re.match(r"^(?P.+?)/(?P.+?)$", resource_name) + if m: + full_resource_name = self._format_resource_name(**m.groupdict()) + else: + raise ValueError( + f"`{resource_name}` is not a valid PublisherModel resource " + "name or model garden id." + ) + + self._gca_resource = getattr(self.api_client, self._getter_method)( + name=full_resource_name, retry=base._DEFAULT_RETRY + ) diff --git a/google/cloud/aiplatform/tensorboard/tensorboard_resource.py b/google/cloud/aiplatform/tensorboard/tensorboard_resource.py index 76e89ca6bd..fcccb3b2b8 100644 --- a/google/cloud/aiplatform/tensorboard/tensorboard_resource.py +++ b/google/cloud/aiplatform/tensorboard/tensorboard_resource.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2021 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -24,14 +24,18 @@ from google.cloud.aiplatform import base from google.cloud.aiplatform import initializer from google.cloud.aiplatform import utils -from google.cloud.aiplatform.compat.types import tensorboard as gca_tensorboard +from google.cloud.aiplatform.compat.types import ( + tensorboard as gca_tensorboard, +) from google.cloud.aiplatform.compat.types import ( tensorboard_data as gca_tensorboard_data, ) from google.cloud.aiplatform.compat.types import ( tensorboard_experiment as gca_tensorboard_experiment, ) -from google.cloud.aiplatform.compat.types import tensorboard_run as gca_tensorboard_run +from google.cloud.aiplatform.compat.types import ( + tensorboard_run as gca_tensorboard_run, +) from google.cloud.aiplatform.compat.types import ( tensorboard_service as gca_tensorboard_service, ) @@ -101,6 +105,7 @@ def create( request_metadata: Optional[Sequence[Tuple[str, str]]] = (), encryption_spec_key_name: Optional[str] = None, create_request_timeout: Optional[float] = None, + is_default: bool = False, ) -> "Tensorboard": """Creates a new tensorboard. @@ -156,6 +161,10 @@ def create( Overrides encryption_spec_key_name set in aiplatform.init. create_request_timeout (float): Optional. The timeout for the create request in seconds. + is_default (bool): + If the TensorBoard instance is default or not. The default + TensorBoard instance will be used by Experiment/ExperimentRun + when needed if no TensorBoard instance is explicitly specified. Returns: tensorboard (Tensorboard): @@ -182,6 +191,7 @@ def create( display_name=display_name, description=description, labels=labels, + is_default=is_default, encryption_spec=encryption_spec, ) @@ -210,6 +220,7 @@ def update( labels: Optional[Dict[str, str]] = None, request_metadata: Optional[Sequence[Tuple[str, str]]] = (), encryption_spec_key_name: Optional[str] = None, + is_default: Optional[bool] = None, ) -> "Tensorboard": """Updates an existing tensorboard. 
@@ -251,6 +262,11 @@ def update( If set, this Tensorboard and all sub-resources of this Tensorboard will be secured by this key. Overrides encryption_spec_key_name set in aiplatform.init. + is_default (bool): + Optional. If the TensorBoard instance is default or not. + The default TensorBoard instance will be used by + Experiment/ExperimentRun when needed if no TensorBoard instance + is explicitly specified. Returns: Tensorboard: The managed tensorboard resource. @@ -268,6 +284,9 @@ def update( utils.validate_labels(labels) update_mask.append("labels") + if is_default is not None: + update_mask.append("is_default") + encryption_spec = None if encryption_spec_key_name: encryption_spec = initializer.global_config.get_encryption_spec( @@ -282,6 +301,7 @@ def update( display_name=display_name, description=description, labels=labels, + is_default=is_default, encryption_spec=encryption_spec, ) diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index a3a55b3b2f..beaf7c9093 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -2438,7 +2438,9 @@ def _run( ( self._column_transformations, column_names, - ) = dataset._get_default_column_transformations(target_column) + ) = column_transformations_utils.get_default_column_transformations( + dataset=dataset, target_column=target_column + ) _LOGGER.info( "The column transformation of type 'auto' was set for the following columns: %s." 
diff --git a/google/cloud/aiplatform/utils/__init__.py b/google/cloud/aiplatform/utils/__init__.py index 14f2e79ef8..e0faf2460b 100644 --- a/google/cloud/aiplatform/utils/__init__.py +++ b/google/cloud/aiplatform/utils/__init__.py @@ -44,12 +44,14 @@ index_service_client_v1beta1, index_endpoint_service_client_v1beta1, job_service_client_v1beta1, + match_service_client_v1beta1, metadata_service_client_v1beta1, model_service_client_v1beta1, pipeline_service_client_v1beta1, prediction_service_client_v1beta1, tensorboard_service_client_v1beta1, vizier_service_client_v1beta1, + model_garden_service_client_v1beta1, ) from google.cloud.aiplatform.compat.services import ( dataset_service_client_v1, @@ -85,6 +87,7 @@ prediction_service_client_v1beta1.PredictionServiceClient, pipeline_service_client_v1beta1.PipelineServiceClient, job_service_client_v1beta1.JobServiceClient, + match_service_client_v1beta1.MatchServiceClient, metadata_service_client_v1beta1.MetadataServiceClient, tensorboard_service_client_v1beta1.TensorboardServiceClient, vizier_service_client_v1beta1.VizierServiceClient, @@ -598,6 +601,12 @@ class PredictionClientWithOverride(ClientWithOverride): ) +class MatchClientWithOverride(ClientWithOverride): + _is_temporary = False + _default_version = compat.V1BETA1 + _version_map = ((compat.V1BETA1, match_service_client_v1beta1.MatchServiceClient),) + + class MetadataClientWithOverride(ClientWithOverride): _is_temporary = True _default_version = compat.DEFAULT_VERSION @@ -625,6 +634,14 @@ class VizierClientWithOverride(ClientWithOverride): ) +class ModelGardenClientWithOverride(ClientWithOverride): + _is_temporary = True + _default_version = compat.V1BETA1 + _version_map = ( + (compat.V1BETA1, model_garden_service_client_v1beta1.ModelGardenServiceClient), + ) + + VertexAiServiceClientWithOverride = TypeVar( "VertexAiServiceClientWithOverride", DatasetClientWithOverride, @@ -632,12 +649,14 @@ class VizierClientWithOverride(ClientWithOverride): 
FeaturestoreClientWithOverride, JobClientWithOverride, ModelClientWithOverride, + MatchClientWithOverride, PipelineClientWithOverride, PipelineJobClientWithOverride, PredictionClientWithOverride, MetadataClientWithOverride, TensorboardClientWithOverride, VizierClientWithOverride, + ModelGardenClientWithOverride, ) diff --git a/google/cloud/aiplatform/utils/column_transformations_utils.py b/google/cloud/aiplatform/utils/column_transformations_utils.py index fe7c16983c..14dfead1f4 100644 --- a/google/cloud/aiplatform/utils/column_transformations_utils.py +++ b/google/cloud/aiplatform/utils/column_transformations_utils.py @@ -16,7 +16,6 @@ # from typing import Dict, List, Optional, Tuple -import warnings from google.cloud.aiplatform import datasets @@ -51,9 +50,9 @@ def get_default_column_transformations( def validate_and_get_column_transformations( - column_specs: Optional[Dict[str, str]], - column_transformations: Optional[List[Dict[str, Dict[str, str]]]], -) -> List[Dict[str, Dict[str, str]]]: + column_specs: Optional[Dict[str, str]] = None, + column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None, +) -> Optional[List[Dict[str, Dict[str, str]]]]: """Validates column specs and transformations, then returns processed transformations. Args: @@ -91,21 +90,13 @@ def validate_and_get_column_transformations( # user populated transformations if column_transformations is not None and column_specs is not None: raise ValueError( - "Both column_transformations and column_specs were passed. Only one is allowed." + "Both column_transformations and column_specs were passed. Only " + "one is allowed." ) - if column_transformations is not None: - warnings.simplefilter("always", DeprecationWarning) - warnings.warn( - "consider using column_specs instead. 
column_transformations will be deprecated in the future.", - DeprecationWarning, - stacklevel=2, - ) - - return column_transformations elif column_specs is not None: return [ {transformation: {"column_name": column_name}} for column_name, transformation in column_specs.items() ] else: - return None + return column_transformations diff --git a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py index e1251ccfff..085d0e2f28 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.24.1" # {x-release-please-version} +__version__ = "1.25.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py index e1251ccfff..085d0e2f28 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.24.1" # {x-release-please-version} +__version__ = "1.25.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py index e1251ccfff..085d0e2f28 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.24.1" # {x-release-please-version} +__version__ = "1.25.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py index e1251ccfff..085d0e2f28 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.24.1" # {x-release-please-version} +__version__ = "1.25.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py index e1251ccfff..085d0e2f28 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.24.1" # {x-release-please-version} +__version__ = "1.25.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py index e1251ccfff..085d0e2f28 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.24.1" # {x-release-please-version} +__version__ = "1.25.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py index e1251ccfff..085d0e2f28 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.24.1" # {x-release-please-version} +__version__ = "1.25.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py index e1251ccfff..085d0e2f28 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.24.1" # {x-release-please-version} +__version__ = "1.25.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py index e1251ccfff..085d0e2f28 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.24.1" # {x-release-please-version} +__version__ = "1.25.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py index e1251ccfff..085d0e2f28 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.24.1" # {x-release-please-version} +__version__ = "1.25.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py index e1251ccfff..085d0e2f28 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.24.1" # {x-release-please-version} +__version__ = "1.25.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py index e1251ccfff..085d0e2f28 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.24.1" # {x-release-please-version} +__version__ = "1.25.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py index e1251ccfff..085d0e2f28 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.24.1" # {x-release-please-version} +__version__ = "1.25.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py index e1251ccfff..085d0e2f28 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.24.1" # {x-release-please-version} +__version__ = "1.25.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py index e1251ccfff..085d0e2f28 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.24.1" # {x-release-please-version} +__version__ = "1.25.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py index e1251ccfff..085d0e2f28 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.24.1" # {x-release-please-version} +__version__ = "1.25.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/version.py b/google/cloud/aiplatform/version.py index 599983a982..bc48376025 100644 --- a/google/cloud/aiplatform/version.py +++ b/google/cloud/aiplatform/version.py @@ -15,4 +15,4 @@ # limitations under the License. 
# -__version__ = "1.24.1" +__version__ = "1.25.0" diff --git a/google/cloud/aiplatform_v1/__init__.py b/google/cloud/aiplatform_v1/__init__.py index 9ac01c979e..5db0b76c98 100644 --- a/google/cloud/aiplatform_v1/__init__.py +++ b/google/cloud/aiplatform_v1/__init__.py @@ -114,6 +114,9 @@ from .types.endpoint_service import GetEndpointRequest from .types.endpoint_service import ListEndpointsRequest from .types.endpoint_service import ListEndpointsResponse +from .types.endpoint_service import MutateDeployedModelOperationMetadata +from .types.endpoint_service import MutateDeployedModelRequest +from .types.endpoint_service import MutateDeployedModelResponse from .types.endpoint_service import UndeployModelOperationMetadata from .types.endpoint_service import UndeployModelRequest from .types.endpoint_service import UndeployModelResponse @@ -925,6 +928,9 @@ "MutateDeployedIndexOperationMetadata", "MutateDeployedIndexRequest", "MutateDeployedIndexResponse", + "MutateDeployedModelOperationMetadata", + "MutateDeployedModelRequest", + "MutateDeployedModelResponse", "NasJob", "NasJobOutput", "NasJobSpec", diff --git a/google/cloud/aiplatform_v1/gapic_metadata.json b/google/cloud/aiplatform_v1/gapic_metadata.json index 3f995cf513..07ef390ae5 100644 --- a/google/cloud/aiplatform_v1/gapic_metadata.json +++ b/google/cloud/aiplatform_v1/gapic_metadata.json @@ -169,6 +169,11 @@ "list_endpoints" ] }, + "MutateDeployedModel": { + "methods": [ + "mutate_deployed_model" + ] + }, "UndeployModel": { "methods": [ "undeploy_model" @@ -209,6 +214,11 @@ "list_endpoints" ] }, + "MutateDeployedModel": { + "methods": [ + "mutate_deployed_model" + ] + }, "UndeployModel": { "methods": [ "undeploy_model" diff --git a/google/cloud/aiplatform_v1/gapic_version.py b/google/cloud/aiplatform_v1/gapic_version.py index e1251ccfff..085d0e2f28 100644 --- a/google/cloud/aiplatform_v1/gapic_version.py +++ b/google/cloud/aiplatform_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific 
language governing permissions and # limitations under the License. # -__version__ = "1.24.1" # {x-release-please-version} +__version__ = "1.25.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py index d5d02ecfa9..62d11f5cca 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py @@ -1162,6 +1162,165 @@ async def sample_undeploy_model(): # Done; return the response. return response + async def mutate_deployed_model( + self, + request: Optional[ + Union[endpoint_service.MutateDeployedModelRequest, dict] + ] = None, + *, + endpoint: Optional[str] = None, + deployed_model: Optional[gca_endpoint.DeployedModel] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Updates an existing deployed model. Updatable fields include + ``min_replica_count``, ``max_replica_count``, + ``autoscaling_metric_specs``, ``disable_container_logging`` (v1 + only), and ``enable_container_logging`` (v1beta1 only). + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + async def sample_mutate_deployed_model(): + # Create a client + client = aiplatform_v1.EndpointServiceAsyncClient() + + # Initialize request argument(s) + deployed_model = aiplatform_v1.DeployedModel() + deployed_model.dedicated_resources.min_replica_count = 1803 + deployed_model.model = "model_value" + + request = aiplatform_v1.MutateDeployedModelRequest( + endpoint="endpoint_value", + deployed_model=deployed_model, + ) + + # Make the request + operation = client.mutate_deployed_model(request=request) + + print("Waiting for operation to complete...") + + response = (await operation).result() + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.cloud.aiplatform_v1.types.MutateDeployedModelRequest, dict]]): + The request object. Request message for + [EndpointService.MutateDeployedModel][google.cloud.aiplatform.v1.EndpointService.MutateDeployedModel]. + endpoint (:class:`str`): + Required. The name of the Endpoint resource into which + to mutate a DeployedModel. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + + This corresponds to the ``endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + deployed_model (:class:`google.cloud.aiplatform_v1.types.DeployedModel`): + Required. The DeployedModel to be mutated within the + Endpoint. 
Only the following fields can be mutated: + + - ``min_replica_count`` in either + [DedicatedResources][google.cloud.aiplatform.v1.DedicatedResources] + or + [AutomaticResources][google.cloud.aiplatform.v1.AutomaticResources] + - ``max_replica_count`` in either + [DedicatedResources][google.cloud.aiplatform.v1.DedicatedResources] + or + [AutomaticResources][google.cloud.aiplatform.v1.AutomaticResources] + - [autoscaling_metric_specs][google.cloud.aiplatform.v1.DedicatedResources.autoscaling_metric_specs] + - ``disable_container_logging`` (v1 only) + - ``enable_container_logging`` (v1beta1 only) + + This corresponds to the ``deployed_model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. The update mask applies to the resource. See + [google.protobuf.FieldMask][google.protobuf.FieldMask]. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1.types.MutateDeployedModelResponse` Response message for + [EndpointService.MutateDeployedModel][google.cloud.aiplatform.v1.EndpointService.MutateDeployedModel]. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
+ has_flattened_params = any([endpoint, deployed_model, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = endpoint_service.MutateDeployedModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if endpoint is not None: + request.endpoint = endpoint + if deployed_model is not None: + request.deployed_model = deployed_model + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.mutate_deployed_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + endpoint_service.MutateDeployedModelResponse, + metadata_type=endpoint_service.MutateDeployedModelOperationMetadata, + ) + + # Done; return the response. + return response + async def list_operations( self, request: Optional[operations_pb2.ListOperationsRequest] = None, diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1/services/endpoint_service/client.py index 3a410f8ced..0794fa7938 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/client.py @@ -1440,6 +1440,165 @@ def sample_undeploy_model(): # Done; return the response. 
return response + def mutate_deployed_model( + self, + request: Optional[ + Union[endpoint_service.MutateDeployedModelRequest, dict] + ] = None, + *, + endpoint: Optional[str] = None, + deployed_model: Optional[gca_endpoint.DeployedModel] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Updates an existing deployed model. Updatable fields include + ``min_replica_count``, ``max_replica_count``, + ``autoscaling_metric_specs``, ``disable_container_logging`` (v1 + only), and ``enable_container_logging`` (v1beta1 only). + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + def sample_mutate_deployed_model(): + # Create a client + client = aiplatform_v1.EndpointServiceClient() + + # Initialize request argument(s) + deployed_model = aiplatform_v1.DeployedModel() + deployed_model.dedicated_resources.min_replica_count = 1803 + deployed_model.model = "model_value" + + request = aiplatform_v1.MutateDeployedModelRequest( + endpoint="endpoint_value", + deployed_model=deployed_model, + ) + + # Make the request + operation = client.mutate_deployed_model(request=request) + + print("Waiting for operation to complete...") + + response = operation.result() + + # Handle the response + print(response) + + Args: + request (Union[google.cloud.aiplatform_v1.types.MutateDeployedModelRequest, dict]): + The request object. 
Request message for + [EndpointService.MutateDeployedModel][google.cloud.aiplatform.v1.EndpointService.MutateDeployedModel]. + endpoint (str): + Required. The name of the Endpoint resource into which + to mutate a DeployedModel. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + + This corresponds to the ``endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + deployed_model (google.cloud.aiplatform_v1.types.DeployedModel): + Required. The DeployedModel to be mutated within the + Endpoint. Only the following fields can be mutated: + + - ``min_replica_count`` in either + [DedicatedResources][google.cloud.aiplatform.v1.DedicatedResources] + or + [AutomaticResources][google.cloud.aiplatform.v1.AutomaticResources] + - ``max_replica_count`` in either + [DedicatedResources][google.cloud.aiplatform.v1.DedicatedResources] + or + [AutomaticResources][google.cloud.aiplatform.v1.AutomaticResources] + - [autoscaling_metric_specs][google.cloud.aiplatform.v1.DedicatedResources.autoscaling_metric_specs] + - ``disable_container_logging`` (v1 only) + - ``enable_container_logging`` (v1beta1 only) + + This corresponds to the ``deployed_model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The update mask applies to the resource. See + [google.protobuf.FieldMask][google.protobuf.FieldMask]. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. 
+ + The result type for the operation will be :class:`google.cloud.aiplatform_v1.types.MutateDeployedModelResponse` Response message for + [EndpointService.MutateDeployedModel][google.cloud.aiplatform.v1.EndpointService.MutateDeployedModel]. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([endpoint, deployed_model, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a endpoint_service.MutateDeployedModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, endpoint_service.MutateDeployedModelRequest): + request = endpoint_service.MutateDeployedModelRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if endpoint is not None: + request.endpoint = endpoint + if deployed_model is not None: + request.deployed_model = deployed_model + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.mutate_deployed_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. 
+ response = gac_operation.from_gapic( + response, + self._transport.operations_client, + endpoint_service.MutateDeployedModelResponse, + metadata_type=endpoint_service.MutateDeployedModelOperationMetadata, + ) + + # Done; return the response. + return response + def __enter__(self) -> "EndpointServiceClient": return self diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py index 5b068c606f..c7551e0d9e 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py @@ -166,6 +166,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + self.mutate_deployed_model: gapic_v1.method.wrap_method( + self.mutate_deployed_model, + default_timeout=None, + client_info=client_info, + ), } def close(self): @@ -248,6 +253,15 @@ def undeploy_model( ]: raise NotImplementedError() + @property + def mutate_deployed_model( + self, + ) -> Callable[ + [endpoint_service.MutateDeployedModelRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + @property def list_operations( self, diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py index 87fdc4ea98..4f479c5b75 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py @@ -439,6 +439,37 @@ def undeploy_model( ) return self._stubs["undeploy_model"] + @property + def mutate_deployed_model( + self, + ) -> Callable[ + [endpoint_service.MutateDeployedModelRequest], operations_pb2.Operation + ]: + r"""Return a callable for the mutate deployed model method over gRPC. + + Updates an existing deployed model. 
Updatable fields include + ``min_replica_count``, ``max_replica_count``, + ``autoscaling_metric_specs``, ``disable_container_logging`` (v1 + only), and ``enable_container_logging`` (v1beta1 only). + + Returns: + Callable[[~.MutateDeployedModelRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "mutate_deployed_model" not in self._stubs: + self._stubs["mutate_deployed_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/MutateDeployedModel", + request_serializer=endpoint_service.MutateDeployedModelRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["mutate_deployed_model"] + def close(self): self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py index 9521232d05..3c54d71126 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py @@ -455,6 +455,38 @@ def undeploy_model( ) return self._stubs["undeploy_model"] + @property + def mutate_deployed_model( + self, + ) -> Callable[ + [endpoint_service.MutateDeployedModelRequest], + Awaitable[operations_pb2.Operation], + ]: + r"""Return a callable for the mutate deployed model method over gRPC. + + Updates an existing deployed model. Updatable fields include + ``min_replica_count``, ``max_replica_count``, + ``autoscaling_metric_specs``, ``disable_container_logging`` (v1 + only), and ``enable_container_logging`` (v1beta1 only). 
+ + Returns: + Callable[[~.MutateDeployedModelRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "mutate_deployed_model" not in self._stubs: + self._stubs["mutate_deployed_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/MutateDeployedModel", + request_serializer=endpoint_service.MutateDeployedModelRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["mutate_deployed_model"] + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/featurestore_service/async_client.py b/google/cloud/aiplatform_v1/services/featurestore_service/async_client.py index 2a816c9392..f9213d30ae 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/featurestore_service/async_client.py @@ -686,7 +686,7 @@ async def sample_update_featurestore(): - ``labels`` - ``online_serving_config.fixed_node_count`` - ``online_serving_config.scaling`` - - ``online_storage_ttl_days`` (available in Preview) + - ``online_storage_ttl_days`` This corresponds to the ``update_mask`` field on the ``request`` instance; if ``request`` is provided, this @@ -1348,7 +1348,7 @@ async def sample_update_entity_type(): - ``monitoring_config.import_features_analysis.anomaly_detection_baseline`` - ``monitoring_config.numerical_threshold_config.value`` - ``monitoring_config.categorical_threshold_config.value`` - - ``offline_storage_ttl_days`` (available in Preview) + - ``offline_storage_ttl_days`` This corresponds to the ``update_mask`` field on the ``request`` instance; if ``request`` is provided, this diff --git 
a/google/cloud/aiplatform_v1/services/featurestore_service/client.py b/google/cloud/aiplatform_v1/services/featurestore_service/client.py index 373c36a9b2..211dfc8353 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_service/client.py +++ b/google/cloud/aiplatform_v1/services/featurestore_service/client.py @@ -954,7 +954,7 @@ def sample_update_featurestore(): - ``labels`` - ``online_serving_config.fixed_node_count`` - ``online_serving_config.scaling`` - - ``online_storage_ttl_days`` (available in Preview) + - ``online_storage_ttl_days`` This corresponds to the ``update_mask`` field on the ``request`` instance; if ``request`` is provided, this @@ -1616,7 +1616,7 @@ def sample_update_entity_type(): - ``monitoring_config.import_features_analysis.anomaly_detection_baseline`` - ``monitoring_config.numerical_threshold_config.value`` - ``monitoring_config.categorical_threshold_config.value`` - - ``offline_storage_ttl_days`` (available in Preview) + - ``offline_storage_ttl_days`` This corresponds to the ``update_mask`` field on the ``request`` instance; if ``request`` is provided, this diff --git a/google/cloud/aiplatform_v1/services/job_service/async_client.py b/google/cloud/aiplatform_v1/services/job_service/async_client.py index 1ee18fcead..d6dca7db2f 100644 --- a/google/cloud/aiplatform_v1/services/job_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/job_service/async_client.py @@ -103,6 +103,8 @@ class JobServiceAsyncClient: parse_batch_prediction_job_path = staticmethod( JobServiceClient.parse_batch_prediction_job_path ) + context_path = staticmethod(JobServiceClient.context_path) + parse_context_path = staticmethod(JobServiceClient.parse_context_path) custom_job_path = staticmethod(JobServiceClient.custom_job_path) parse_custom_job_path = staticmethod(JobServiceClient.parse_custom_job_path) data_labeling_job_path = staticmethod(JobServiceClient.data_labeling_job_path) diff --git a/google/cloud/aiplatform_v1/services/job_service/client.py 
b/google/cloud/aiplatform_v1/services/job_service/client.py index d02d845577..33939bc4f8 100644 --- a/google/cloud/aiplatform_v1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1/services/job_service/client.py @@ -236,6 +236,30 @@ def parse_batch_prediction_job_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def context_path( + project: str, + location: str, + metadata_store: str, + context: str, + ) -> str: + """Returns a fully-qualified context string.""" + return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/contexts/{context}".format( + project=project, + location=location, + metadata_store=metadata_store, + context=context, + ) + + @staticmethod + def parse_context_path(path: str) -> Dict[str, str]: + """Parses a context path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/contexts/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def custom_job_path( project: str, diff --git a/google/cloud/aiplatform_v1/services/migration_service/client.py b/google/cloud/aiplatform_v1/services/migration_service/client.py index c42427af48..b43dcee16e 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1/services/migration_service/client.py @@ -230,40 +230,40 @@ def parse_dataset_path(path: str) -> Dict[str, str]: @staticmethod def dataset_path( project: str, + location: str, dataset: str, ) -> str: """Returns a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format( + return "projects/{project}/locations/{location}/datasets/{dataset}".format( project=project, + location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parses a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + 
r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod def dataset_path( project: str, - location: str, dataset: str, ) -> str: """Returns a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( + return "projects/{project}/datasets/{dataset}".format( project=project, - location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parses a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod diff --git a/google/cloud/aiplatform_v1/types/__init__.py b/google/cloud/aiplatform_v1/types/__init__.py index ad54a6e047..b24c57f74e 100644 --- a/google/cloud/aiplatform_v1/types/__init__.py +++ b/google/cloud/aiplatform_v1/types/__init__.py @@ -104,6 +104,9 @@ GetEndpointRequest, ListEndpointsRequest, ListEndpointsResponse, + MutateDeployedModelOperationMetadata, + MutateDeployedModelRequest, + MutateDeployedModelResponse, UndeployModelOperationMetadata, UndeployModelRequest, UndeployModelResponse, @@ -705,6 +708,9 @@ "GetEndpointRequest", "ListEndpointsRequest", "ListEndpointsResponse", + "MutateDeployedModelOperationMetadata", + "MutateDeployedModelRequest", + "MutateDeployedModelResponse", "UndeployModelOperationMetadata", "UndeployModelRequest", "UndeployModelResponse", diff --git a/google/cloud/aiplatform_v1/types/accelerator_type.py b/google/cloud/aiplatform_v1/types/accelerator_type.py index d9d8da0546..41e63804f8 100644 --- a/google/cloud/aiplatform_v1/types/accelerator_type.py +++ b/google/cloud/aiplatform_v1/types/accelerator_type.py @@ -47,6 +47,8 @@ class AcceleratorType(proto.Enum): Nvidia Tesla T4 GPU. NVIDIA_TESLA_A100 (8): Nvidia Tesla A100 GPU. + NVIDIA_L4 (11): + Nvidia L4 GPU. TPU_V2 (6): TPU v2. 
TPU_V3 (7): @@ -61,6 +63,7 @@ class AcceleratorType(proto.Enum): NVIDIA_TESLA_P4 = 4 NVIDIA_TESLA_T4 = 5 NVIDIA_TESLA_A100 = 8 + NVIDIA_L4 = 11 TPU_V2 = 6 TPU_V3 = 7 TPU_V4_POD = 10 diff --git a/google/cloud/aiplatform_v1/types/batch_prediction_job.py b/google/cloud/aiplatform_v1/types/batch_prediction_job.py index 61fe141c9c..803e3f4c04 100644 --- a/google/cloud/aiplatform_v1/types/batch_prediction_job.py +++ b/google/cloud/aiplatform_v1/types/batch_prediction_job.py @@ -226,10 +226,10 @@ class BatchPredictionJob(proto.Message): disable_container_logging (bool): For custom-trained Models and AutoML Tabular Models, the container of the DeployedModel instances will send - ``stderr`` and ``stdout`` streams to Stackdriver Logging by + ``stderr`` and ``stdout`` streams to Cloud Logging by default. Please note that the logs incur cost, which are subject to `Cloud Logging - pricing `__. + pricing `__. User can disable container logging by setting this flag to true. diff --git a/google/cloud/aiplatform_v1/types/custom_job.py b/google/cloud/aiplatform_v1/types/custom_job.py index 98504c7d83..62619846ea 100644 --- a/google/cloud/aiplatform_v1/types/custom_job.py +++ b/google/cloud/aiplatform_v1/types/custom_job.py @@ -268,6 +268,13 @@ class CustomJobSpec(proto.Message): [Trial.web_access_uris][google.cloud.aiplatform.v1.Trial.web_access_uris] (within [HyperparameterTuningJob.trials][google.cloud.aiplatform.v1.HyperparameterTuningJob.trials]). + experiment (str): + Optional. The Experiment associated with this job. Format: + ``projects/{project}/locations/{location}/metadataStores/{metadataStores}/contexts/{experiment-name}`` + experiment_run (str): + Optional. The Experiment Run associated with this job. 
+ Format: + ``projects/{project}/locations/{location}/metadataStores/{metadataStores}/contexts/{experiment-name}-{experiment-run-name}`` """ worker_pool_specs: MutableSequence["WorkerPoolSpec"] = proto.RepeatedField( @@ -309,6 +316,14 @@ class CustomJobSpec(proto.Message): proto.BOOL, number=16, ) + experiment: str = proto.Field( + proto.STRING, + number=17, + ) + experiment_run: str = proto.Field( + proto.STRING, + number=18, + ) class WorkerPoolSpec(proto.Message): diff --git a/google/cloud/aiplatform_v1/types/dataset.py b/google/cloud/aiplatform_v1/types/dataset.py index fb6a08776c..f30cbe582c 100644 --- a/google/cloud/aiplatform_v1/types/dataset.py +++ b/google/cloud/aiplatform_v1/types/dataset.py @@ -91,8 +91,7 @@ class Dataset(proto.Message): title. saved_queries (MutableSequence[google.cloud.aiplatform_v1.types.SavedQuery]): All SavedQueries belong to the Dataset will be returned in - List/Get Dataset response. The - [annotation_specs][SavedQuery.annotation_specs] field will + List/Get Dataset response. The annotation_specs field will not be populated except for UI cases which will only use [annotation_spec_count][google.cloud.aiplatform.v1.SavedQuery.annotation_spec_count]. In CreateDataset request, a SavedQuery is created together @@ -266,10 +265,9 @@ class ExportDataConfig(proto.Message): This field is a member of `oneof`_ ``split``. annotations_filter (str): - A filter on Annotations of the Dataset. Only Annotations on - to-be-exported DataItems(specified by [data_items_filter][]) - that match this filter will be exported. The filter syntax - is the same as in + An expression for filtering what part of the Dataset is to + be exported. Only Annotations that match this filter will be + exported. The filter syntax is the same as in [ListAnnotations][google.cloud.aiplatform.v1.DatasetService.ListAnnotations]. 
""" diff --git a/google/cloud/aiplatform_v1/types/endpoint.py b/google/cloud/aiplatform_v1/types/endpoint.py index 34d2c26420..6eacc2d3dd 100644 --- a/google/cloud/aiplatform_v1/types/endpoint.py +++ b/google/cloud/aiplatform_v1/types/endpoint.py @@ -122,7 +122,8 @@ class Endpoint(proto.Message): model_deployment_monitoring_job (str): Output only. Resource name of the Model Monitoring job associated with this Endpoint if monitoring is enabled by - [CreateModelDeploymentMonitoringJob][]. Format: + [JobService.CreateModelDeploymentMonitoringJob][google.cloud.aiplatform.v1.JobService.CreateModelDeploymentMonitoringJob]. + Format: ``projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}`` predict_request_response_logging_config (google.cloud.aiplatform_v1.types.PredictRequestResponseLoggingConfig): Configures the request-response logging for @@ -278,24 +279,23 @@ class DeployedModel(proto.Message): disable_container_logging (bool): For custom-trained Models and AutoML Tabular Models, the container of the DeployedModel instances will send - ``stderr`` and ``stdout`` streams to Stackdriver Logging by + ``stderr`` and ``stdout`` streams to Cloud Logging by default. Please note that the logs incur cost, which are subject to `Cloud Logging - pricing `__. + pricing `__. User can disable container logging by setting this flag to true. enable_access_logging (bool): If true, online prediction access logs are - sent to StackDriver Logging. + sent to Cloud Logging. These logs are like standard server access logs, containing information like timestamp and latency for each prediction request. - Note that Stackdriver logs may incur a cost, - especially if your project receives prediction - requests at a high queries per second rate - (QPS). Estimate your costs before enabling this - option. + Note that logs may incur a cost, especially if + your project receives prediction requests at a + high queries per second rate (QPS). 
Estimate + your costs before enabling this option. private_endpoints (google.cloud.aiplatform_v1.types.PrivateEndpoints): Output only. Provide paths for users to send predict/explain/health requests directly to the deployed diff --git a/google/cloud/aiplatform_v1/types/endpoint_service.py b/google/cloud/aiplatform_v1/types/endpoint_service.py index 3abfeff411..128ee99ce8 100644 --- a/google/cloud/aiplatform_v1/types/endpoint_service.py +++ b/google/cloud/aiplatform_v1/types/endpoint_service.py @@ -40,6 +40,9 @@ "UndeployModelRequest", "UndeployModelResponse", "UndeployModelOperationMetadata", + "MutateDeployedModelRequest", + "MutateDeployedModelResponse", + "MutateDeployedModelOperationMetadata", }, ) @@ -417,4 +420,81 @@ class UndeployModelOperationMetadata(proto.Message): ) +class MutateDeployedModelRequest(proto.Message): + r"""Request message for + [EndpointService.MutateDeployedModel][google.cloud.aiplatform.v1.EndpointService.MutateDeployedModel]. + + Attributes: + endpoint (str): + Required. The name of the Endpoint resource into which to + mutate a DeployedModel. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + deployed_model (google.cloud.aiplatform_v1.types.DeployedModel): + Required. The DeployedModel to be mutated within the + Endpoint. Only the following fields can be mutated: + + - ``min_replica_count`` in either + [DedicatedResources][google.cloud.aiplatform.v1.DedicatedResources] + or + [AutomaticResources][google.cloud.aiplatform.v1.AutomaticResources] + - ``max_replica_count`` in either + [DedicatedResources][google.cloud.aiplatform.v1.DedicatedResources] + or + [AutomaticResources][google.cloud.aiplatform.v1.AutomaticResources] + - [autoscaling_metric_specs][google.cloud.aiplatform.v1.DedicatedResources.autoscaling_metric_specs] + - ``disable_container_logging`` (v1 only) + - ``enable_container_logging`` (v1beta1 only) + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. 
The update mask applies to the resource. See + [google.protobuf.FieldMask][google.protobuf.FieldMask]. + """ + + endpoint: str = proto.Field( + proto.STRING, + number=1, + ) + deployed_model: gca_endpoint.DeployedModel = proto.Field( + proto.MESSAGE, + number=2, + message=gca_endpoint.DeployedModel, + ) + update_mask: field_mask_pb2.FieldMask = proto.Field( + proto.MESSAGE, + number=4, + message=field_mask_pb2.FieldMask, + ) + + +class MutateDeployedModelResponse(proto.Message): + r"""Response message for + [EndpointService.MutateDeployedModel][google.cloud.aiplatform.v1.EndpointService.MutateDeployedModel]. + + Attributes: + deployed_model (google.cloud.aiplatform_v1.types.DeployedModel): + The DeployedModel that's being mutated. + """ + + deployed_model: gca_endpoint.DeployedModel = proto.Field( + proto.MESSAGE, + number=1, + message=gca_endpoint.DeployedModel, + ) + + +class MutateDeployedModelOperationMetadata(proto.Message): + r"""Runtime operation information for + [EndpointService.MutateDeployedModel][google.cloud.aiplatform.v1.EndpointService.MutateDeployedModel]. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1.types.GenericOperationMetadata): + The operation generic information. + """ + + generic_metadata: operation.GenericOperationMetadata = proto.Field( + proto.MESSAGE, + number=1, + message=operation.GenericOperationMetadata, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/entity_type.py b/google/cloud/aiplatform_v1/types/entity_type.py index dbfe1b3738..7c9394543e 100644 --- a/google/cloud/aiplatform_v1/types/entity_type.py +++ b/google/cloud/aiplatform_v1/types/entity_type.py @@ -84,6 +84,14 @@ class EntityType(proto.Message): [FeaturestoreMonitoringConfig.monitoring_interval] specified, snapshot analysis monitoring is enabled. Otherwise, snapshot analysis monitoring is disabled. + offline_storage_ttl_days (int): + Optional. Config for data retention policy in offline + storage. 
TTL in days for feature values that will be stored + in offline storage. The Feature Store offline storage + periodically removes obsolete feature values older than + ``offline_storage_ttl_days`` since the feature generation + time. If unset (or explicitly set to 0), default to 4000 + days TTL. """ name: str = proto.Field( @@ -120,6 +128,10 @@ class EntityType(proto.Message): message=featurestore_monitoring.FeaturestoreMonitoringConfig, ) ) + offline_storage_ttl_days: int = proto.Field( + proto.INT32, + number=10, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/evaluated_annotation.py b/google/cloud/aiplatform_v1/types/evaluated_annotation.py index 830d7eca51..58707978c5 100644 --- a/google/cloud/aiplatform_v1/types/evaluated_annotation.py +++ b/google/cloud/aiplatform_v1/types/evaluated_annotation.py @@ -89,10 +89,6 @@ class EvaluatedAnnotation(proto.Message): ancestor ModelEvaluation. The EvaluatedDataItemView consists of all ground truths and predictions on [data_item_payload][google.cloud.aiplatform.v1.EvaluatedAnnotation.data_item_payload]. - - Can be passed in - [GetEvaluatedDataItemView's][ModelService.GetEvaluatedDataItemView][] - [id][GetEvaluatedDataItemViewRequest.id]. explanations (MutableSequence[google.cloud.aiplatform_v1.types.EvaluatedAnnotationExplanation]): Explanations of [predictions][google.cloud.aiplatform.v1.EvaluatedAnnotation.predictions]. diff --git a/google/cloud/aiplatform_v1/types/explanation.py b/google/cloud/aiplatform_v1/types/explanation.py index 7a79588727..12250ae98f 100644 --- a/google/cloud/aiplatform_v1/types/explanation.py +++ b/google/cloud/aiplatform_v1/types/explanation.py @@ -686,10 +686,9 @@ class ExplanationSpecOverride(proto.Message): Attributes: parameters (google.cloud.aiplatform_v1.types.ExplanationParameters): - The parameters to be overridden. Note that the - [method][google.cloud.aiplatform.v1.ExplanationParameters.method] - cannot be changed. 
If not specified, no parameter is - overridden. + The parameters to be overridden. Note that + the attribution method cannot be changed. If not + specified, no parameter is overridden. metadata (google.cloud.aiplatform_v1.types.ExplanationMetadataOverride): The metadata to be overridden. If not specified, no metadata is overridden. diff --git a/google/cloud/aiplatform_v1/types/feature.py b/google/cloud/aiplatform_v1/types/feature.py index 9b94f786e2..4bfb54cc53 100644 --- a/google/cloud/aiplatform_v1/types/feature.py +++ b/google/cloud/aiplatform_v1/types/feature.py @@ -125,11 +125,11 @@ class ValueType(proto.Enum): BYTES = 13 class MonitoringStatsAnomaly(proto.Message): - r"""A list of historical [Snapshot - Analysis][FeaturestoreMonitoringConfig.SnapshotAnalysis] or [Import - Feature Analysis] - [FeaturestoreMonitoringConfig.ImportFeatureAnalysis] stats requested - by user, sorted by + r"""A list of historical + [SnapshotAnalysis][google.cloud.aiplatform.v1.FeaturestoreMonitoringConfig.SnapshotAnalysis] + or + [ImportFeaturesAnalysis][google.cloud.aiplatform.v1.FeaturestoreMonitoringConfig.ImportFeaturesAnalysis] + stats requested by user, sorted by [FeatureStatsAnomaly.start_time][google.cloud.aiplatform.v1.FeatureStatsAnomaly.start_time] descending. diff --git a/google/cloud/aiplatform_v1/types/featurestore.py b/google/cloud/aiplatform_v1/types/featurestore.py index 92ffe88677..1bae61b56b 100644 --- a/google/cloud/aiplatform_v1/types/featurestore.py +++ b/google/cloud/aiplatform_v1/types/featurestore.py @@ -74,6 +74,15 @@ class Featurestore(proto.Message): serving. state (google.cloud.aiplatform_v1.types.Featurestore.State): Output only. State of the featurestore. + online_storage_ttl_days (int): + Optional. TTL in days for feature values that will be stored + in online serving storage. The Feature Store online storage + periodically removes obsolete feature values older than + ``online_storage_ttl_days`` since the feature generation + time. 
Note that ``online_storage_ttl_days`` should be less + than or equal to ``offline_storage_ttl_days`` for each + EntityType under a featurestore. If not set, default to 4000 + days encryption_spec (google.cloud.aiplatform_v1.types.EncryptionSpec): Optional. Customer-managed encryption key spec for data storage. If set, both of the @@ -211,6 +220,10 @@ class Scaling(proto.Message): number=8, enum=State, ) + online_storage_ttl_days: int = proto.Field( + proto.INT32, + number=13, + ) encryption_spec: gca_encryption_spec.EncryptionSpec = proto.Field( proto.MESSAGE, number=10, diff --git a/google/cloud/aiplatform_v1/types/featurestore_online_service.py b/google/cloud/aiplatform_v1/types/featurestore_online_service.py index 100a4e08f9..0905551fff 100644 --- a/google/cloud/aiplatform_v1/types/featurestore_online_service.py +++ b/google/cloud/aiplatform_v1/types/featurestore_online_service.py @@ -175,7 +175,7 @@ class Header(proto.Message): ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entityType}``. feature_descriptors (MutableSequence[google.cloud.aiplatform_v1.types.ReadFeatureValuesResponse.FeatureDescriptor]): List of Feature metadata corresponding to each piece of - [ReadFeatureValuesResponse.data][]. + [ReadFeatureValuesResponse.EntityView.data][google.cloud.aiplatform.v1.ReadFeatureValuesResponse.EntityView.data]. 
""" entity_type: str = proto.Field( diff --git a/google/cloud/aiplatform_v1/types/featurestore_service.py b/google/cloud/aiplatform_v1/types/featurestore_service.py index 0f60deadcc..0d93f413fe 100644 --- a/google/cloud/aiplatform_v1/types/featurestore_service.py +++ b/google/cloud/aiplatform_v1/types/featurestore_service.py @@ -268,7 +268,7 @@ class UpdateFeaturestoreRequest(proto.Message): - ``labels`` - ``online_serving_config.fixed_node_count`` - ``online_serving_config.scaling`` - - ``online_storage_ttl_days`` (available in Preview) + - ``online_storage_ttl_days`` """ featurestore: gca_featurestore.Featurestore = proto.Field( @@ -1074,7 +1074,7 @@ class UpdateEntityTypeRequest(proto.Message): - ``monitoring_config.import_features_analysis.anomaly_detection_baseline`` - ``monitoring_config.numerical_threshold_config.value`` - ``monitoring_config.categorical_threshold_config.value`` - - ``offline_storage_ttl_days`` (available in Preview) + - ``offline_storage_ttl_days`` """ entity_type: gca_entity_type.EntityType = proto.Field( diff --git a/google/cloud/aiplatform_v1/types/index_endpoint.py b/google/cloud/aiplatform_v1/types/index_endpoint.py index 232a859c1f..d9e43da6c7 100644 --- a/google/cloud/aiplatform_v1/types/index_endpoint.py +++ b/google/cloud/aiplatform_v1/types/index_endpoint.py @@ -259,15 +259,15 @@ class DeployedIndex(proto.Message): efficiency. enable_access_logging (bool): Optional. If true, private endpoint's access - logs are sent to StackDriver Logging. + logs are sent to Cloud Logging. These logs are like standard server access logs, containing information like timestamp and latency for each MatchRequest. - Note that Stackdriver logs may incur a cost, - especially if the deployed index receives a high - queries per second rate (QPS). Estimate your - costs before enabling this option. + Note that logs may incur a cost, especially if + the deployed index receives a high queries per + second rate (QPS). 
Estimate your costs before + enabling this option. deployed_index_auth_config (google.cloud.aiplatform_v1.types.DeployedIndexAuthConfig): Optional. If set, the authentication is enabled for the private endpoint. diff --git a/google/cloud/aiplatform_v1/types/model.py b/google/cloud/aiplatform_v1/types/model.py index d55d87f006..dfac8c79fc 100644 --- a/google/cloud/aiplatform_v1/types/model.py +++ b/google/cloud/aiplatform_v1/types/model.py @@ -923,12 +923,15 @@ class ModelSourceType(proto.Enum): MODEL_GARDEN (4): The Model is saved or tuned from Model Garden. + GENIE (5): + The Model is saved or tuned from Genie. """ MODEL_SOURCE_TYPE_UNSPECIFIED = 0 AUTOML = 1 CUSTOM = 2 BQML = 3 MODEL_GARDEN = 4 + GENIE = 5 source_type: ModelSourceType = proto.Field( proto.ENUM, diff --git a/google/cloud/aiplatform_v1/types/model_service.py b/google/cloud/aiplatform_v1/types/model_service.py index f734c409f1..580eab5d29 100644 --- a/google/cloud/aiplatform_v1/types/model_service.py +++ b/google/cloud/aiplatform_v1/types/model_service.py @@ -307,8 +307,10 @@ class ListModelVersionsRequest(proto.Message): The standard list page size. page_token (str): The standard list page token. Typically obtained via - [ListModelVersionsResponse.next_page_token][google.cloud.aiplatform.v1.ListModelVersionsResponse.next_page_token] - of the previous [ModelService.ListModelversions][] call. + [next_page_token][google.cloud.aiplatform.v1.ListModelVersionsResponse.next_page_token] + of the previous + [ListModelVersions][google.cloud.aiplatform.v1.ModelService.ListModelVersions] + call. filter (str): An expression for filtering the results of the request. For field names both snake_case and camelCase are supported. 
diff --git a/google/cloud/aiplatform_v1/types/pipeline_job.py b/google/cloud/aiplatform_v1/types/pipeline_job.py index cd26661907..0daea846e0 100644 --- a/google/cloud/aiplatform_v1/types/pipeline_job.py +++ b/google/cloud/aiplatform_v1/types/pipeline_job.py @@ -374,7 +374,8 @@ class PipelineTaskDetail(proto.Message): task is at the root level. task_name (str): Output only. The user specified name of the task that is - defined in [PipelineJob.spec][]. + defined in + [pipeline_spec][google.cloud.aiplatform.v1.PipelineJob.pipeline_spec]. create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Task create time. start_time (google.protobuf.timestamp_pb2.Timestamp): diff --git a/google/cloud/aiplatform_v1/types/tensorboard_experiment.py b/google/cloud/aiplatform_v1/types/tensorboard_experiment.py index 20ee6abc85..3b54244066 100644 --- a/google/cloud/aiplatform_v1/types/tensorboard_experiment.py +++ b/google/cloud/aiplatform_v1/types/tensorboard_experiment.py @@ -54,7 +54,7 @@ class TensorboardExperiment(proto.Message): The labels with user-defined metadata to organize your Datasets. - Label keys and values can be no longer than 64 characters + Label keys and values cannot be longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. No more than 64 user labels can be @@ -62,13 +62,13 @@ class TensorboardExperiment(proto.Message): See https://goo.gl/xmQnxf for more information and examples of labels. System reserved label keys are prefixed with - "aiplatform.googleapis.com/" and are immutable. Following - system labels exist for each Dataset: + ``aiplatform.googleapis.com/`` and are immutable. The + following system labels exist for each Dataset: - - "aiplatform.googleapis.com/dataset_metadata_schema": - - - output only, its value is the - [metadata_schema's][metadata_schema_uri] title. 
+ - ``aiplatform.googleapis.com/dataset_metadata_schema``: + output only. Its value is the + [metadata_schema's][google.cloud.aiplatform.v1.Dataset.metadata_schema_uri] + title. etag (str): Used to perform consistent read-modify-write updates. If not set, a blind "overwrite" update diff --git a/google/cloud/aiplatform_v1/types/tensorboard_service.py b/google/cloud/aiplatform_v1/types/tensorboard_service.py index a159126822..2eb77544de 100644 --- a/google/cloud/aiplatform_v1/types/tensorboard_service.py +++ b/google/cloud/aiplatform_v1/types/tensorboard_service.py @@ -1226,12 +1226,12 @@ class ExportTensorboardTimeSeriesDataRequest(proto.Message): 10000. Values above 10000 are coerced to 10000. page_token (str): A page token, received from a previous - [TensorboardService.ExportTensorboardTimeSeries][] call. - Provide this to retrieve the subsequent page. + [ExportTensorboardTimeSeriesData][google.cloud.aiplatform.v1.TensorboardService.ExportTensorboardTimeSeriesData] + call. Provide this to retrieve the subsequent page. When paginating, all other parameters provided to - [TensorboardService.ExportTensorboardTimeSeries][] must - match the call that provided the page token. + [ExportTensorboardTimeSeriesData][google.cloud.aiplatform.v1.TensorboardService.ExportTensorboardTimeSeriesData] + must match the call that provided the page token. order_by (str): Field to use to sort the TensorboardTimeSeries' data. By default, @@ -1270,9 +1270,9 @@ class ExportTensorboardTimeSeriesDataResponse(proto.Message): The returned time series data points. next_page_token (str): A token, which can be sent as - [ExportTensorboardTimeSeriesRequest.page_token][] to - retrieve the next page. If this field is omitted, there are - no subsequent pages. + [page_token][google.cloud.aiplatform.v1.ExportTensorboardTimeSeriesDataRequest.page_token] + to retrieve the next page. If this field is omitted, there + are no subsequent pages. 
""" @property diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index dec5937f59..2583dc3c63 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -48,6 +48,8 @@ from .services.metadata_service import MetadataServiceAsyncClient from .services.migration_service import MigrationServiceClient from .services.migration_service import MigrationServiceAsyncClient +from .services.model_garden_service import ModelGardenServiceClient +from .services.model_garden_service import ModelGardenServiceAsyncClient from .services.model_service import ModelServiceClient from .services.model_service import ModelServiceAsyncClient from .services.pipeline_service import PipelineServiceClient @@ -138,6 +140,9 @@ from .types.endpoint_service import GetEndpointRequest from .types.endpoint_service import ListEndpointsRequest from .types.endpoint_service import ListEndpointsResponse +from .types.endpoint_service import MutateDeployedModelOperationMetadata +from .types.endpoint_service import MutateDeployedModelRequest +from .types.endpoint_service import MutateDeployedModelResponse from .types.endpoint_service import UndeployModelOperationMetadata from .types.endpoint_service import UndeployModelRequest from .types.endpoint_service import UndeployModelResponse @@ -388,6 +393,7 @@ from .types.migration_service import MigrateResourceResponse from .types.migration_service import SearchMigratableResourcesRequest from .types.migration_service import SearchMigratableResourcesResponse +from .types.model import LargeModelReference from .types.model import Model from .types.model import ModelContainerSpec from .types.model import ModelSourceInfo @@ -409,6 +415,8 @@ ) from .types.model_evaluation import ModelEvaluation from .types.model_evaluation_slice import ModelEvaluationSlice +from .types.model_garden_service import GetPublisherModelRequest +from .types.model_garden_service import 
PublisherModelView from .types.model_monitoring import ModelMonitoringAlertConfig from .types.model_monitoring import ModelMonitoringConfig from .types.model_monitoring import ModelMonitoringObjectiveConfig @@ -477,6 +485,7 @@ from .types.prediction_service import PredictRequest from .types.prediction_service import PredictResponse from .types.prediction_service import RawPredictRequest +from .types.publisher_model import PublisherModel from .types.saved_query import SavedQuery from .types.schedule import Schedule from .types.schedule_service import CreateScheduleRequest @@ -603,6 +612,7 @@ "MatchServiceAsyncClient", "MetadataServiceAsyncClient", "MigrationServiceAsyncClient", + "ModelGardenServiceAsyncClient", "ModelServiceAsyncClient", "PipelineServiceAsyncClient", "PredictionServiceAsyncClient", @@ -848,6 +858,7 @@ "GetNasJobRequest", "GetNasTrialDetailRequest", "GetPipelineJobRequest", + "GetPublisherModelRequest", "GetScheduleRequest", "GetSpecialistPoolRequest", "GetStudyRequest", @@ -879,6 +890,7 @@ "IntegratedGradientsAttribution", "JobServiceClient", "JobState", + "LargeModelReference", "LineageSubgraph", "ListAnnotationsRequest", "ListAnnotationsResponse", @@ -979,6 +991,7 @@ "ModelEvaluation", "ModelEvaluationSlice", "ModelExplanation", + "ModelGardenServiceClient", "ModelMonitoringAlertConfig", "ModelMonitoringConfig", "ModelMonitoringObjectiveConfig", @@ -988,6 +1001,9 @@ "MutateDeployedIndexOperationMetadata", "MutateDeployedIndexRequest", "MutateDeployedIndexResponse", + "MutateDeployedModelOperationMetadata", + "MutateDeployedModelRequest", + "MutateDeployedModelResponse", "NasJob", "NasJobOutput", "NasJobSpec", @@ -1016,6 +1032,8 @@ "Presets", "PrivateEndpoints", "PrivateServiceConnectConfig", + "PublisherModel", + "PublisherModelView", "PurgeArtifactsMetadata", "PurgeArtifactsRequest", "PurgeArtifactsResponse", diff --git a/google/cloud/aiplatform_v1beta1/gapic_metadata.json b/google/cloud/aiplatform_v1beta1/gapic_metadata.json index 
37fe5a28f1..fe65cacbd1 100644 --- a/google/cloud/aiplatform_v1beta1/gapic_metadata.json +++ b/google/cloud/aiplatform_v1beta1/gapic_metadata.json @@ -233,6 +233,11 @@ "list_endpoints" ] }, + "MutateDeployedModel": { + "methods": [ + "mutate_deployed_model" + ] + }, "UndeployModel": { "methods": [ "undeploy_model" @@ -273,6 +278,11 @@ "list_endpoints" ] }, + "MutateDeployedModel": { + "methods": [ + "mutate_deployed_model" + ] + }, "UndeployModel": { "methods": [ "undeploy_model" @@ -1499,6 +1509,30 @@ } } }, + "ModelGardenService": { + "clients": { + "grpc": { + "libraryClient": "ModelGardenServiceClient", + "rpcs": { + "GetPublisherModel": { + "methods": [ + "get_publisher_model" + ] + } + } + }, + "grpc-async": { + "libraryClient": "ModelGardenServiceAsyncClient", + "rpcs": { + "GetPublisherModel": { + "methods": [ + "get_publisher_model" + ] + } + } + } + } + }, "ModelService": { "clients": { "grpc": { diff --git a/google/cloud/aiplatform_v1beta1/gapic_version.py b/google/cloud/aiplatform_v1beta1/gapic_version.py index e1251ccfff..085d0e2f28 100644 --- a/google/cloud/aiplatform_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.24.1" # {x-release-please-version} +__version__ = "1.25.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py index 0ebf7b7d92..15983ccbc2 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py @@ -1168,6 +1168,165 @@ async def sample_undeploy_model(): # Done; return the response. 
return response + async def mutate_deployed_model( + self, + request: Optional[ + Union[endpoint_service.MutateDeployedModelRequest, dict] + ] = None, + *, + endpoint: Optional[str] = None, + deployed_model: Optional[gca_endpoint.DeployedModel] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Updates an existing deployed model. Updatable fields include + ``min_replica_count``, ``max_replica_count``, + ``autoscaling_metric_specs``, ``disable_container_logging`` (v1 + only), and ``enable_container_logging`` (v1beta1 only). + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1beta1 + + async def sample_mutate_deployed_model(): + # Create a client + client = aiplatform_v1beta1.EndpointServiceAsyncClient() + + # Initialize request argument(s) + deployed_model = aiplatform_v1beta1.DeployedModel() + deployed_model.dedicated_resources.min_replica_count = 1803 + deployed_model.model = "model_value" + + request = aiplatform_v1beta1.MutateDeployedModelRequest( + endpoint="endpoint_value", + deployed_model=deployed_model, + ) + + # Make the request + operation = client.mutate_deployed_model(request=request) + + print("Waiting for operation to complete...") + + response = (await operation).result() + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.cloud.aiplatform_v1beta1.types.MutateDeployedModelRequest, dict]]): + The 
request object. Request message for + [EndpointService.MutateDeployedModel][google.cloud.aiplatform.v1beta1.EndpointService.MutateDeployedModel]. + endpoint (:class:`str`): + Required. The name of the Endpoint resource into which + to mutate a DeployedModel. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + + This corresponds to the ``endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + deployed_model (:class:`google.cloud.aiplatform_v1beta1.types.DeployedModel`): + Required. The DeployedModel to be mutated within the + Endpoint. Only the following fields can be mutated: + + - ``min_replica_count`` in either + [DedicatedResources][google.cloud.aiplatform.v1beta1.DedicatedResources] + or + [AutomaticResources][google.cloud.aiplatform.v1beta1.AutomaticResources] + - ``max_replica_count`` in either + [DedicatedResources][google.cloud.aiplatform.v1beta1.DedicatedResources] + or + [AutomaticResources][google.cloud.aiplatform.v1beta1.AutomaticResources] + - [autoscaling_metric_specs][google.cloud.aiplatform.v1beta1.DedicatedResources.autoscaling_metric_specs] + - ``disable_container_logging`` (v1 only) + - ``enable_container_logging`` (v1beta1 only) + + This corresponds to the ``deployed_model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. The update mask applies to the resource. See + [google.protobuf.FieldMask][google.protobuf.FieldMask]. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.MutateDeployedModelResponse` Response message for + [EndpointService.MutateDeployedModel][google.cloud.aiplatform.v1beta1.EndpointService.MutateDeployedModel]. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([endpoint, deployed_model, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = endpoint_service.MutateDeployedModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if endpoint is not None: + request.endpoint = endpoint + if deployed_model is not None: + request.deployed_model = deployed_model + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.mutate_deployed_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. 
+ response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + endpoint_service.MutateDeployedModelResponse, + metadata_type=endpoint_service.MutateDeployedModelOperationMetadata, + ) + + # Done; return the response. + return response + async def list_operations( self, request: Optional[operations_pb2.ListOperationsRequest] = None, diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py index 46a2ebf223..e91b4381b3 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py @@ -1462,6 +1462,165 @@ def sample_undeploy_model(): # Done; return the response. return response + def mutate_deployed_model( + self, + request: Optional[ + Union[endpoint_service.MutateDeployedModelRequest, dict] + ] = None, + *, + endpoint: Optional[str] = None, + deployed_model: Optional[gca_endpoint.DeployedModel] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Updates an existing deployed model. Updatable fields include + ``min_replica_count``, ``max_replica_count``, + ``autoscaling_metric_specs``, ``disable_container_logging`` (v1 + only), and ``enable_container_logging`` (v1beta1 only). + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1beta1 + + def sample_mutate_deployed_model(): + # Create a client + client = aiplatform_v1beta1.EndpointServiceClient() + + # Initialize request argument(s) + deployed_model = aiplatform_v1beta1.DeployedModel() + deployed_model.dedicated_resources.min_replica_count = 1803 + deployed_model.model = "model_value" + + request = aiplatform_v1beta1.MutateDeployedModelRequest( + endpoint="endpoint_value", + deployed_model=deployed_model, + ) + + # Make the request + operation = client.mutate_deployed_model(request=request) + + print("Waiting for operation to complete...") + + response = operation.result() + + # Handle the response + print(response) + + Args: + request (Union[google.cloud.aiplatform_v1beta1.types.MutateDeployedModelRequest, dict]): + The request object. Request message for + [EndpointService.MutateDeployedModel][google.cloud.aiplatform.v1beta1.EndpointService.MutateDeployedModel]. + endpoint (str): + Required. The name of the Endpoint resource into which + to mutate a DeployedModel. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + + This corresponds to the ``endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + deployed_model (google.cloud.aiplatform_v1beta1.types.DeployedModel): + Required. The DeployedModel to be mutated within the + Endpoint. 
Only the following fields can be mutated: + + - ``min_replica_count`` in either + [DedicatedResources][google.cloud.aiplatform.v1beta1.DedicatedResources] + or + [AutomaticResources][google.cloud.aiplatform.v1beta1.AutomaticResources] + - ``max_replica_count`` in either + [DedicatedResources][google.cloud.aiplatform.v1beta1.DedicatedResources] + or + [AutomaticResources][google.cloud.aiplatform.v1beta1.AutomaticResources] + - [autoscaling_metric_specs][google.cloud.aiplatform.v1beta1.DedicatedResources.autoscaling_metric_specs] + - ``disable_container_logging`` (v1 only) + - ``enable_container_logging`` (v1beta1 only) + + This corresponds to the ``deployed_model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The update mask applies to the resource. See + [google.protobuf.FieldMask][google.protobuf.FieldMask]. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.MutateDeployedModelResponse` Response message for + [EndpointService.MutateDeployedModel][google.cloud.aiplatform.v1beta1.EndpointService.MutateDeployedModel]. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
+ has_flattened_params = any([endpoint, deployed_model, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a endpoint_service.MutateDeployedModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, endpoint_service.MutateDeployedModelRequest): + request = endpoint_service.MutateDeployedModelRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if endpoint is not None: + request.endpoint = endpoint + if deployed_model is not None: + request.deployed_model = deployed_model + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.mutate_deployed_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + endpoint_service.MutateDeployedModelResponse, + metadata_type=endpoint_service.MutateDeployedModelOperationMetadata, + ) + + # Done; return the response. 
+ return response + def __enter__(self) -> "EndpointServiceClient": return self diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py index a22eb5ef9a..bdc7682030 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py @@ -166,6 +166,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=5.0, client_info=client_info, ), + self.mutate_deployed_model: gapic_v1.method.wrap_method( + self.mutate_deployed_model, + default_timeout=None, + client_info=client_info, + ), } def close(self): @@ -248,6 +253,15 @@ def undeploy_model( ]: raise NotImplementedError() + @property + def mutate_deployed_model( + self, + ) -> Callable[ + [endpoint_service.MutateDeployedModelRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + @property def list_operations( self, diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py index 1cd4d9c5c7..535b668105 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py @@ -439,6 +439,37 @@ def undeploy_model( ) return self._stubs["undeploy_model"] + @property + def mutate_deployed_model( + self, + ) -> Callable[ + [endpoint_service.MutateDeployedModelRequest], operations_pb2.Operation + ]: + r"""Return a callable for the mutate deployed model method over gRPC. + + Updates an existing deployed model. Updatable fields include + ``min_replica_count``, ``max_replica_count``, + ``autoscaling_metric_specs``, ``disable_container_logging`` (v1 + only), and ``enable_container_logging`` (v1beta1 only). 
+ + Returns: + Callable[[~.MutateDeployedModelRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "mutate_deployed_model" not in self._stubs: + self._stubs["mutate_deployed_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/MutateDeployedModel", + request_serializer=endpoint_service.MutateDeployedModelRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["mutate_deployed_model"] + def close(self): self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py index 74caf183de..ae0c77a093 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py @@ -455,6 +455,38 @@ def undeploy_model( ) return self._stubs["undeploy_model"] + @property + def mutate_deployed_model( + self, + ) -> Callable[ + [endpoint_service.MutateDeployedModelRequest], + Awaitable[operations_pb2.Operation], + ]: + r"""Return a callable for the mutate deployed model method over gRPC. + + Updates an existing deployed model. Updatable fields include + ``min_replica_count``, ``max_replica_count``, + ``autoscaling_metric_specs``, ``disable_container_logging`` (v1 + only), and ``enable_container_logging`` (v1beta1 only). + + Returns: + Callable[[~.MutateDeployedModelRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. 
+ # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "mutate_deployed_model" not in self._stubs: + self._stubs["mutate_deployed_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/MutateDeployedModel", + request_serializer=endpoint_service.MutateDeployedModelRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["mutate_deployed_model"] + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py index 2610f51a6c..23520be59d 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py @@ -687,7 +687,7 @@ async def sample_update_featurestore(): - ``labels`` - ``online_serving_config.fixed_node_count`` - ``online_serving_config.scaling`` - - ``online_storage_ttl_days`` (available in Preview) + - ``online_storage_ttl_days`` This corresponds to the ``update_mask`` field on the ``request`` instance; if ``request`` is provided, this @@ -1349,7 +1349,7 @@ async def sample_update_entity_type(): - ``monitoring_config.import_features_analysis.anomaly_detection_baseline`` - ``monitoring_config.numerical_threshold_config.value`` - ``monitoring_config.categorical_threshold_config.value`` - - ``offline_storage_ttl_days`` (available in Preview) + - ``offline_storage_ttl_days`` This corresponds to the ``update_mask`` field on the ``request`` instance; if ``request`` is provided, this diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py index a0410f432f..a54f2a4ca3 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py +++ 
b/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py @@ -955,7 +955,7 @@ def sample_update_featurestore(): - ``labels`` - ``online_serving_config.fixed_node_count`` - ``online_serving_config.scaling`` - - ``online_storage_ttl_days`` (available in Preview) + - ``online_storage_ttl_days`` This corresponds to the ``update_mask`` field on the ``request`` instance; if ``request`` is provided, this @@ -1617,7 +1617,7 @@ def sample_update_entity_type(): - ``monitoring_config.import_features_analysis.anomaly_detection_baseline`` - ``monitoring_config.numerical_threshold_config.value`` - ``monitoring_config.categorical_threshold_config.value`` - - ``offline_storage_ttl_days`` (available in Preview) + - ``offline_storage_ttl_days`` This corresponds to the ``update_mask`` field on the ``request`` instance; if ``request`` is provided, this diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py index f6ca7d5f68..16ecda161b 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py @@ -105,6 +105,8 @@ class JobServiceAsyncClient: parse_batch_prediction_job_path = staticmethod( JobServiceClient.parse_batch_prediction_job_path ) + context_path = staticmethod(JobServiceClient.context_path) + parse_context_path = staticmethod(JobServiceClient.parse_context_path) custom_job_path = staticmethod(JobServiceClient.custom_job_path) parse_custom_job_path = staticmethod(JobServiceClient.parse_custom_job_path) data_labeling_job_path = staticmethod(JobServiceClient.data_labeling_job_path) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/client.py b/google/cloud/aiplatform_v1beta1/services/job_service/client.py index d103372c51..2ee69e44c2 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/client.py +++ 
b/google/cloud/aiplatform_v1beta1/services/job_service/client.py @@ -238,6 +238,30 @@ def parse_batch_prediction_job_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def context_path( + project: str, + location: str, + metadata_store: str, + context: str, + ) -> str: + """Returns a fully-qualified context string.""" + return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/contexts/{context}".format( + project=project, + location=location, + metadata_store=metadata_store, + context=context, + ) + + @staticmethod + def parse_context_path(path: str) -> Dict[str, str]: + """Parses a context path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/metadataStores/(?P.+?)/contexts/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def custom_job_path( project: str, diff --git a/google/cloud/aiplatform_v1beta1/services/model_garden_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/model_garden_service/__init__.py new file mode 100644 index 0000000000..bf7ee03b38 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/model_garden_service/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from .client import ModelGardenServiceClient +from .async_client import ModelGardenServiceAsyncClient + +__all__ = ( + "ModelGardenServiceClient", + "ModelGardenServiceAsyncClient", +) diff --git a/google/cloud/aiplatform_v1beta1/services/model_garden_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/model_garden_service/async_client.py new file mode 100644 index 0000000000..f433dde2ee --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/model_garden_service/async_client.py @@ -0,0 +1,1024 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from collections import OrderedDict +import functools +import re +from typing import ( + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from google.cloud.aiplatform_v1beta1 import gapic_version as package_version + +from google.api_core.client_options import ClientOptions +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + +from google.cloud.aiplatform_v1beta1.types import model +from google.cloud.aiplatform_v1beta1.types import model_garden_service +from google.cloud.aiplatform_v1beta1.types import publisher_model +from google.cloud.location import locations_pb2 # type: ignore +from google.iam.v1 import iam_policy_pb2 # type: ignore +from google.iam.v1 import policy_pb2 # type: ignore +from google.longrunning import operations_pb2 +from .transports.base import ModelGardenServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import ModelGardenServiceGrpcAsyncIOTransport +from .client import ModelGardenServiceClient + + +class ModelGardenServiceAsyncClient: + """The interface of Model Garden Service.""" + + _client: ModelGardenServiceClient + + DEFAULT_ENDPOINT = ModelGardenServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = ModelGardenServiceClient.DEFAULT_MTLS_ENDPOINT + + publisher_model_path = staticmethod(ModelGardenServiceClient.publisher_model_path) + parse_publisher_model_path = staticmethod( + ModelGardenServiceClient.parse_publisher_model_path + ) + common_billing_account_path = staticmethod( + ModelGardenServiceClient.common_billing_account_path + ) + 
parse_common_billing_account_path = staticmethod( + ModelGardenServiceClient.parse_common_billing_account_path + ) + common_folder_path = staticmethod(ModelGardenServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + ModelGardenServiceClient.parse_common_folder_path + ) + common_organization_path = staticmethod( + ModelGardenServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + ModelGardenServiceClient.parse_common_organization_path + ) + common_project_path = staticmethod(ModelGardenServiceClient.common_project_path) + parse_common_project_path = staticmethod( + ModelGardenServiceClient.parse_common_project_path + ) + common_location_path = staticmethod(ModelGardenServiceClient.common_location_path) + parse_common_location_path = staticmethod( + ModelGardenServiceClient.parse_common_location_path + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ModelGardenServiceAsyncClient: The constructed client. + """ + return ModelGardenServiceClient.from_service_account_info.__func__(ModelGardenServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ModelGardenServiceAsyncClient: The constructed client. 
+ """ + return ModelGardenServiceClient.from_service_account_file.__func__(ModelGardenServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return ModelGardenServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + + @property + def transport(self) -> ModelGardenServiceTransport: + """Returns the transport used by the client instance. + + Returns: + ModelGardenServiceTransport: The transport used by the client instance. 
+ """ + return self._client.transport + + get_transport_class = functools.partial( + type(ModelGardenServiceClient).get_transport_class, + type(ModelGardenServiceClient), + ) + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Union[str, ModelGardenServiceTransport] = "grpc_asyncio", + client_options: Optional[ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the model garden service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.ModelGardenServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. 
+ + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client = ModelGardenServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + async def get_publisher_model( + self, + request: Optional[ + Union[model_garden_service.GetPublisherModelRequest, dict] + ] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> publisher_model.PublisherModel: + r"""Gets a Model Garden publisher model. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1beta1 + + async def sample_get_publisher_model(): + # Create a client + client = aiplatform_v1beta1.ModelGardenServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.GetPublisherModelRequest( + name="name_value", + ) + + # Make the request + response = await client.get_publisher_model(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.cloud.aiplatform_v1beta1.types.GetPublisherModelRequest, dict]]): + The request object. Request message for + [ModelGardenService.GetPublisherModel][google.cloud.aiplatform.v1beta1.ModelGardenService.GetPublisherModel] + name (:class:`str`): + Required. The name of the PublisherModel resource. 
+ Format: + ``publishers/{publisher}/models/{publisher_model}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.PublisherModel: + A Model Garden Publisher Model. + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = model_garden_service.GetPublisherModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_publisher_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + async def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._client._transport.list_operations, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + async def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._client._transport.get_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_operation( + self, + request: Optional[operations_pb2.DeleteOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Deletes a long-running operation. 
+ + This method indicates that the client is no longer interested + in the operation result. It does not cancel the operation. + If the server doesn't support this method, it returns + `google.rpc.Code.UNIMPLEMENTED`. + + Args: + request (:class:`~.operations_pb2.DeleteOperationRequest`): + The request object. Request message for + `DeleteOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + None + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.DeleteOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._client._transport.delete_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def cancel_operation( + self, + request: Optional[operations_pb2.CancelOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Starts asynchronous cancellation on a long-running operation. + + The server makes a best effort to cancel the operation, but success + is not guaranteed. If the server doesn't support this method, it returns + `google.rpc.Code.UNIMPLEMENTED`. 
+ + Args: + request (:class:`~.operations_pb2.CancelOperationRequest`): + The request object. Request message for + `CancelOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + None + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.CancelOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._client._transport.cancel_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def wait_operation( + self, + request: Optional[operations_pb2.WaitOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Waits until the specified long-running operation is done or reaches at most + a specified timeout, returning the latest state. + + If the operation is already done, the latest state is immediately returned. + If the timeout specified is greater than the default HTTP/RPC timeout, the HTTP/RPC + timeout is used. If the server does not support this method, it returns + `google.rpc.Code.UNIMPLEMENTED`. 
+ + Args: + request (:class:`~.operations_pb2.WaitOperationRequest`): + The request object. Request message for + `WaitOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.WaitOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._client._transport.wait_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def set_iam_policy( + self, + request: Optional[iam_policy_pb2.SetIamPolicyRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> policy_pb2.Policy: + r"""Sets the IAM access control policy on the specified function. + + Replaces any existing policy. + + Args: + request (:class:`~.iam_policy_pb2.SetIamPolicyRequest`): + The request object. Request message for `SetIamPolicy` + method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. 
+ metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.policy_pb2.Policy: + Defines an Identity and Access Management (IAM) policy. + It is used to specify access control policies for Cloud + Platform resources. + A ``Policy`` is a collection of ``bindings``. A + ``binding`` binds one or more ``members`` to a single + ``role``. Members can be user accounts, service + accounts, Google groups, and domains (such as G Suite). + A ``role`` is a named list of permissions (defined by + IAM or configured by users). A ``binding`` can + optionally specify a ``condition``, which is a logic + expression that further constrains the role binding + based on attributes about the request and/or target + resource. + + **JSON Example** + + :: + + { + "bindings": [ + { + "role": "roles/resourcemanager.organizationAdmin", + "members": [ + "user:mike@example.com", + "group:admins@example.com", + "domain:google.com", + "serviceAccount:my-project-id@appspot.gserviceaccount.com" + ] + }, + { + "role": "roles/resourcemanager.organizationViewer", + "members": ["user:eve@example.com"], + "condition": { + "title": "expirable access", + "description": "Does not grant access after Sep 2020", + "expression": "request.time < + timestamp('2020-10-01T00:00:00.000Z')", + } + } + ] + } + + **YAML Example** + + :: + + bindings: + - members: + - user:mike@example.com + - group:admins@example.com + - domain:google.com + - serviceAccount:my-project-id@appspot.gserviceaccount.com + role: roles/resourcemanager.organizationAdmin + - members: + - user:eve@example.com + role: roles/resourcemanager.organizationViewer + condition: + title: expirable access + description: Does not grant access after Sep 2020 + expression: request.time < timestamp('2020-10-01T00:00:00.000Z') + + For a description of IAM and its features, see the `IAM + developer's + guide `__. + """ + # Create or coerce a protobuf request object. 
+ + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = iam_policy_pb2.SetIamPolicyRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._client._transport.set_iam_policy, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_iam_policy( + self, + request: Optional[iam_policy_pb2.GetIamPolicyRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> policy_pb2.Policy: + r"""Gets the IAM access control policy for a function. + + Returns an empty policy if the function exists and does not have a + policy set. + + Args: + request (:class:`~.iam_policy_pb2.GetIamPolicyRequest`): + The request object. Request message for `GetIamPolicy` + method. + retry (google.api_core.retry.Retry): Designation of what errors, if + any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.policy_pb2.Policy: + Defines an Identity and Access Management (IAM) policy. + It is used to specify access control policies for Cloud + Platform resources. + A ``Policy`` is a collection of ``bindings``. A + ``binding`` binds one or more ``members`` to a single + ``role``. 
Members can be user accounts, service + accounts, Google groups, and domains (such as G Suite). + A ``role`` is a named list of permissions (defined by + IAM or configured by users). A ``binding`` can + optionally specify a ``condition``, which is a logic + expression that further constrains the role binding + based on attributes about the request and/or target + resource. + + **JSON Example** + + :: + + { + "bindings": [ + { + "role": "roles/resourcemanager.organizationAdmin", + "members": [ + "user:mike@example.com", + "group:admins@example.com", + "domain:google.com", + "serviceAccount:my-project-id@appspot.gserviceaccount.com" + ] + }, + { + "role": "roles/resourcemanager.organizationViewer", + "members": ["user:eve@example.com"], + "condition": { + "title": "expirable access", + "description": "Does not grant access after Sep 2020", + "expression": "request.time < + timestamp('2020-10-01T00:00:00.000Z')", + } + } + ] + } + + **YAML Example** + + :: + + bindings: + - members: + - user:mike@example.com + - group:admins@example.com + - domain:google.com + - serviceAccount:my-project-id@appspot.gserviceaccount.com + role: roles/resourcemanager.organizationAdmin + - members: + - user:eve@example.com + role: roles/resourcemanager.organizationViewer + condition: + title: expirable access + description: Does not grant access after Sep 2020 + expression: request.time < timestamp('2020-10-01T00:00:00.000Z') + + For a description of IAM and its features, see the `IAM + developer's + guide `__. + """ + # Create or coerce a protobuf request object. + + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = iam_policy_pb2.GetIamPolicyRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
+ rpc = gapic_v1.method.wrap_method( + self._client._transport.get_iam_policy, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def test_iam_permissions( + self, + request: Optional[iam_policy_pb2.TestIamPermissionsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> iam_policy_pb2.TestIamPermissionsResponse: + r"""Tests the specified IAM permissions against the IAM access control + policy for a function. + + If the function does not exist, this will return an empty set + of permissions, not a NOT_FOUND error. + + Args: + request (:class:`~.iam_policy_pb2.TestIamPermissionsRequest`): + The request object. Request message for + `TestIamPermissions` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.iam_policy_pb2.TestIamPermissionsResponse: + Response message for ``TestIamPermissions`` method. + """ + # Create or coerce a protobuf request object. + + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = iam_policy_pb2.TestIamPermissionsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
+ rpc = gapic_v1.method.wrap_method( + self._client._transport.test_iam_permissions, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_location( + self, + request: Optional[locations_pb2.GetLocationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> locations_pb2.Location: + r"""Gets information about a location. + + Args: + request (:class:`~.location_pb2.GetLocationRequest`): + The request object. Request message for + `GetLocation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.location_pb2.Location: + Location object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = locations_pb2.GetLocationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._client._transport.get_location, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. 
+ response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_locations( + self, + request: Optional[locations_pb2.ListLocationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> locations_pb2.ListLocationsResponse: + r"""Lists information about the supported locations for this service. + + Args: + request (:class:`~.location_pb2.ListLocationsRequest`): + The request object. Request message for + `ListLocations` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.location_pb2.ListLocationsResponse: + Response message for ``ListLocations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = locations_pb2.ListLocationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._client._transport.list_locations, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("ModelGardenServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_garden_service/client.py b/google/cloud/aiplatform_v1beta1/services/model_garden_service/client.py new file mode 100644 index 0000000000..51e59da381 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/model_garden_service/client.py @@ -0,0 +1,1250 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from collections import OrderedDict +import os +import re +from typing import ( + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) + +from google.cloud.aiplatform_v1beta1 import gapic_version as package_version + +from google.api_core import client_options as client_options_lib +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + +from google.cloud.aiplatform_v1beta1.types import model +from google.cloud.aiplatform_v1beta1.types import model_garden_service +from google.cloud.aiplatform_v1beta1.types import publisher_model +from google.cloud.location import locations_pb2 # type: ignore +from google.iam.v1 import iam_policy_pb2 # type: ignore +from google.iam.v1 import policy_pb2 # type: ignore +from google.longrunning import operations_pb2 +from .transports.base import ModelGardenServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import ModelGardenServiceGrpcTransport +from .transports.grpc_asyncio import ModelGardenServiceGrpcAsyncIOTransport + + +class ModelGardenServiceClientMeta(type): + """Metaclass for the ModelGardenService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. 
+    """
+
+    _transport_registry = (
+        OrderedDict()
+    )  # type: Dict[str, Type[ModelGardenServiceTransport]]
+    _transport_registry["grpc"] = ModelGardenServiceGrpcTransport
+    _transport_registry["grpc_asyncio"] = ModelGardenServiceGrpcAsyncIOTransport
+
+    def get_transport_class(
+        cls,
+        label: Optional[str] = None,
+    ) -> Type[ModelGardenServiceTransport]:
+        """Returns an appropriate transport class.
+
+        Args:
+            label: The name of the desired transport. If none is
+                provided, then the first transport in the registry is used.
+
+        Returns:
+            The transport class to use.
+        """
+        # If a specific transport is requested, return that one.
+        if label:
+            return cls._transport_registry[label]
+
+        # No transport is requested; return the default (that is, the first one
+        # in the dictionary).
+        return next(iter(cls._transport_registry.values()))
+
+
+class ModelGardenServiceClient(metaclass=ModelGardenServiceClientMeta):
+    """The interface of Model Garden Service."""
+
+    @staticmethod
+    def _get_default_mtls_endpoint(api_endpoint):
+        """Converts api endpoint to mTLS endpoint.
+
+        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+ ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ModelGardenServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ModelGardenServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> ModelGardenServiceTransport: + """Returns the transport used by the client instance. + + Returns: + ModelGardenServiceTransport: The transport used by the client + instance. 
+ """ + return self._transport + + @staticmethod + def publisher_model_path( + publisher: str, + model: str, + ) -> str: + """Returns a fully-qualified publisher_model string.""" + return "publishers/{publisher}/models/{model}".format( + publisher=publisher, + model=model, + ) + + @staticmethod + def parse_publisher_model_path(path: str) -> Dict[str, str]: + """Parses a publisher_model path into its component segments.""" + m = re.match(r"^publishers/(?P.+?)/models/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path( + billing_account: str, + ) -> str: + """Returns a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path( + folder: str, + ) -> str: + """Returns a fully-qualified folder string.""" + return "folders/{folder}".format( + folder=folder, + ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path( + organization: str, + ) -> str: + """Returns a fully-qualified organization string.""" + return "organizations/{organization}".format( + organization=organization, + ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path( + project: str, + ) -> str: + """Returns a fully-qualified project string.""" + return 
"projects/{project}".format( + project=project, + ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path( + project: str, + location: str, + ) -> str: + """Returns a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. 
Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[Union[str, ModelGardenServiceTransport]] = None, + client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the model garden service client. 
+ + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ModelGardenServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. 
+ """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + client_options = cast(client_options_lib.ClientOptions, client_options) + + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options + ) + + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, ModelGardenServiceTransport): + # transport is a ModelGardenServiceTransport instance. + if credentials or client_options.credentials_file or api_key_value: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, provide its scopes " + "directly." 
+ ) + self._transport = transport + else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + api_audience=client_options.api_audience, + ) + + def get_publisher_model( + self, + request: Optional[ + Union[model_garden_service.GetPublisherModelRequest, dict] + ] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> publisher_model.PublisherModel: + r"""Gets a Model Garden publisher model. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1beta1 + + def sample_get_publisher_model(): + # Create a client + client = aiplatform_v1beta1.ModelGardenServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.GetPublisherModelRequest( + name="name_value", + ) + + # Make the request + response = client.get_publisher_model(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.cloud.aiplatform_v1beta1.types.GetPublisherModelRequest, dict]): + The request object. Request message for + [ModelGardenService.GetPublisherModel][google.cloud.aiplatform.v1beta1.ModelGardenService.GetPublisherModel] + name (str): + Required. The name of the PublisherModel resource. + Format: + ``publishers/{publisher}/models/{publisher_model}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.PublisherModel: + A Model Garden Publisher Model. + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a model_garden_service.GetPublisherModelRequest. 
+ # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_garden_service.GetPublisherModelRequest): + request = model_garden_service.GetPublisherModelRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_publisher_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def __enter__(self) -> "ModelGardenServiceClient": + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + + def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. 
+ metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.list_operations, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. 
+ # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.get_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_operation( + self, + request: Optional[operations_pb2.DeleteOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Deletes a long-running operation. + + This method indicates that the client is no longer interested + in the operation result. It does not cancel the operation. + If the server doesn't support this method, it returns + `google.rpc.Code.UNIMPLEMENTED`. + + Args: + request (:class:`~.operations_pb2.DeleteOperationRequest`): + The request object. Request message for + `DeleteOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + None + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. 
+ if isinstance(request, dict): + request = operations_pb2.DeleteOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.delete_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def cancel_operation( + self, + request: Optional[operations_pb2.CancelOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Starts asynchronous cancellation on a long-running operation. + + The server makes a best effort to cancel the operation, but success + is not guaranteed. If the server doesn't support this method, it returns + `google.rpc.Code.UNIMPLEMENTED`. + + Args: + request (:class:`~.operations_pb2.CancelOperationRequest`): + The request object. Request message for + `CancelOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + None + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.CancelOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
+ rpc = gapic_v1.method.wrap_method( + self._transport.cancel_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def wait_operation( + self, + request: Optional[operations_pb2.WaitOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Waits until the specified long-running operation is done or reaches at most + a specified timeout, returning the latest state. + + If the operation is already done, the latest state is immediately returned. + If the timeout specified is greater than the default HTTP/RPC timeout, the HTTP/RPC + timeout is used. If the server does not support this method, it returns + `google.rpc.Code.UNIMPLEMENTED`. + + Args: + request (:class:`~.operations_pb2.WaitOperationRequest`): + The request object. Request message for + `WaitOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.WaitOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
+ rpc = gapic_v1.method.wrap_method( + self._transport.wait_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def set_iam_policy( + self, + request: Optional[iam_policy_pb2.SetIamPolicyRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> policy_pb2.Policy: + r"""Sets the IAM access control policy on the specified function. + + Replaces any existing policy. + + Args: + request (:class:`~.iam_policy_pb2.SetIamPolicyRequest`): + The request object. Request message for `SetIamPolicy` + method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.policy_pb2.Policy: + Defines an Identity and Access Management (IAM) policy. + It is used to specify access control policies for Cloud + Platform resources. + A ``Policy`` is a collection of ``bindings``. A + ``binding`` binds one or more ``members`` to a single + ``role``. Members can be user accounts, service + accounts, Google groups, and domains (such as G Suite). + A ``role`` is a named list of permissions (defined by + IAM or configured by users). A ``binding`` can + optionally specify a ``condition``, which is a logic + expression that further constrains the role binding + based on attributes about the request and/or target + resource. 
+ + **JSON Example** + + :: + + { + "bindings": [ + { + "role": "roles/resourcemanager.organizationAdmin", + "members": [ + "user:mike@example.com", + "group:admins@example.com", + "domain:google.com", + "serviceAccount:my-project-id@appspot.gserviceaccount.com" + ] + }, + { + "role": "roles/resourcemanager.organizationViewer", + "members": ["user:eve@example.com"], + "condition": { + "title": "expirable access", + "description": "Does not grant access after Sep 2020", + "expression": "request.time < + timestamp('2020-10-01T00:00:00.000Z')", + } + } + ] + } + + **YAML Example** + + :: + + bindings: + - members: + - user:mike@example.com + - group:admins@example.com + - domain:google.com + - serviceAccount:my-project-id@appspot.gserviceaccount.com + role: roles/resourcemanager.organizationAdmin + - members: + - user:eve@example.com + role: roles/resourcemanager.organizationViewer + condition: + title: expirable access + description: Does not grant access after Sep 2020 + expression: request.time < timestamp('2020-10-01T00:00:00.000Z') + + For a description of IAM and its features, see the `IAM + developer's + guide `__. + """ + # Create or coerce a protobuf request object. + + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = iam_policy_pb2.SetIamPolicyRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.set_iam_policy, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + def get_iam_policy( + self, + request: Optional[iam_policy_pb2.GetIamPolicyRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> policy_pb2.Policy: + r"""Gets the IAM access control policy for a function. + + Returns an empty policy if the function exists and does not have a + policy set. + + Args: + request (:class:`~.iam_policy_pb2.GetIamPolicyRequest`): + The request object. Request message for `GetIamPolicy` + method. + retry (google.api_core.retry.Retry): Designation of what errors, if + any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.policy_pb2.Policy: + Defines an Identity and Access Management (IAM) policy. + It is used to specify access control policies for Cloud + Platform resources. + A ``Policy`` is a collection of ``bindings``. A + ``binding`` binds one or more ``members`` to a single + ``role``. Members can be user accounts, service + accounts, Google groups, and domains (such as G Suite). + A ``role`` is a named list of permissions (defined by + IAM or configured by users). A ``binding`` can + optionally specify a ``condition``, which is a logic + expression that further constrains the role binding + based on attributes about the request and/or target + resource. 
+ + **JSON Example** + + :: + + { + "bindings": [ + { + "role": "roles/resourcemanager.organizationAdmin", + "members": [ + "user:mike@example.com", + "group:admins@example.com", + "domain:google.com", + "serviceAccount:my-project-id@appspot.gserviceaccount.com" + ] + }, + { + "role": "roles/resourcemanager.organizationViewer", + "members": ["user:eve@example.com"], + "condition": { + "title": "expirable access", + "description": "Does not grant access after Sep 2020", + "expression": "request.time < + timestamp('2020-10-01T00:00:00.000Z')", + } + } + ] + } + + **YAML Example** + + :: + + bindings: + - members: + - user:mike@example.com + - group:admins@example.com + - domain:google.com + - serviceAccount:my-project-id@appspot.gserviceaccount.com + role: roles/resourcemanager.organizationAdmin + - members: + - user:eve@example.com + role: roles/resourcemanager.organizationViewer + condition: + title: expirable access + description: Does not grant access after Sep 2020 + expression: request.time < timestamp('2020-10-01T00:00:00.000Z') + + For a description of IAM and its features, see the `IAM + developer's + guide `__. + """ + # Create or coerce a protobuf request object. + + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = iam_policy_pb2.GetIamPolicyRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.get_iam_policy, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + def test_iam_permissions( + self, + request: Optional[iam_policy_pb2.TestIamPermissionsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> iam_policy_pb2.TestIamPermissionsResponse: + r"""Tests the specified IAM permissions against the IAM access control + policy for a function. + + If the function does not exist, this will return an empty set + of permissions, not a NOT_FOUND error. + + Args: + request (:class:`~.iam_policy_pb2.TestIamPermissionsRequest`): + The request object. Request message for + `TestIamPermissions` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.iam_policy_pb2.TestIamPermissionsResponse: + Response message for ``TestIamPermissions`` method. + """ + # Create or coerce a protobuf request object. + + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = iam_policy_pb2.TestIamPermissionsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.test_iam_permissions, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + def get_location( + self, + request: Optional[locations_pb2.GetLocationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> locations_pb2.Location: + r"""Gets information about a location. + + Args: + request (:class:`~.location_pb2.GetLocationRequest`): + The request object. Request message for + `GetLocation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.location_pb2.Location: + Location object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = locations_pb2.GetLocationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.get_location, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_locations( + self, + request: Optional[locations_pb2.ListLocationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> locations_pb2.ListLocationsResponse: + r"""Lists information about the supported locations for this service. 
+ + Args: + request (:class:`~.location_pb2.ListLocationsRequest`): + The request object. Request message for + `ListLocations` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.location_pb2.ListLocationsResponse: + Response message for ``ListLocations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = locations_pb2.ListLocationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.list_locations, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("ModelGardenServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/__init__.py new file mode 100644 index 0000000000..eb492f9a27 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/__init__.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +from typing import Dict, Type + +from .base import ModelGardenServiceTransport +from .grpc import ModelGardenServiceGrpcTransport +from .grpc_asyncio import ModelGardenServiceGrpcAsyncIOTransport + + +# Compile a registry of transports. 
+_transport_registry = ( + OrderedDict() +) # type: Dict[str, Type[ModelGardenServiceTransport]] +_transport_registry["grpc"] = ModelGardenServiceGrpcTransport +_transport_registry["grpc_asyncio"] = ModelGardenServiceGrpcAsyncIOTransport + +__all__ = ( + "ModelGardenServiceTransport", + "ModelGardenServiceGrpcTransport", + "ModelGardenServiceGrpcAsyncIOTransport", +) diff --git a/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/base.py new file mode 100644 index 0000000000..93ff7bbec9 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/base.py @@ -0,0 +1,256 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import abc +from typing import Awaitable, Callable, Dict, Optional, Sequence, Union + +from google.cloud.aiplatform_v1beta1 import gapic_version as package_version + +import google.auth # type: ignore +import google.api_core +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.cloud.aiplatform_v1beta1.types import model_garden_service +from google.cloud.aiplatform_v1beta1.types import publisher_model +from google.cloud.location import locations_pb2 # type: ignore +from google.iam.v1 import iam_policy_pb2 # type: ignore +from google.iam.v1 import policy_pb2 # type: ignore +from google.longrunning import operations_pb2 + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +class ModelGardenServiceTransport(abc.ABC): + """Abstract transport class for ModelGardenService.""" + + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + + DEFAULT_HOST: str = "aiplatform.googleapis.com" + + def __init__( + self, + *, + host: str = DEFAULT_HOST, + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. 
+ credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + """ + + scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} + + # Save the scopes. + self._scopes = scopes + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise core_exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = google.auth.load_credentials_from_file( + credentials_file, **scopes_kwargs, quota_project_id=quota_project_id + ) + elif credentials is None: + credentials, _ = google.auth.default( + **scopes_kwargs, quota_project_id=quota_project_id + ) + # Don't apply audience if the credentials file passed from user. + if hasattr(credentials, "with_gdch_audience"): + credentials = credentials.with_gdch_audience( + api_audience if api_audience else host + ) + + # If the credentials are service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + + # Save the credentials. 
+ self._credentials = credentials + + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ":" not in host: + host += ":443" + self._host = host + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.get_publisher_model: gapic_v1.method.wrap_method( + self.get_publisher_model, + default_timeout=None, + client_info=client_info, + ), + } + + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + + @property + def get_publisher_model( + self, + ) -> Callable[ + [model_garden_service.GetPublisherModelRequest], + Union[ + publisher_model.PublisherModel, Awaitable[publisher_model.PublisherModel] + ], + ]: + raise NotImplementedError() + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], + Union[ + operations_pb2.ListOperationsResponse, + Awaitable[operations_pb2.ListOperationsResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_operation( + self, + ) -> Callable[ + [operations_pb2.GetOperationRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def cancel_operation( + self, + ) -> Callable[[operations_pb2.CancelOperationRequest], None,]: + raise NotImplementedError() + + @property + def delete_operation( + self, + ) -> Callable[[operations_pb2.DeleteOperationRequest], None,]: + raise NotImplementedError() + + @property + def wait_operation( + self, + ) -> Callable[ + [operations_pb2.WaitOperationRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def set_iam_policy( + self, + ) -> Callable[ + [iam_policy_pb2.SetIamPolicyRequest], + Union[policy_pb2.Policy, 
Awaitable[policy_pb2.Policy]], + ]: + raise NotImplementedError() + + @property + def get_iam_policy( + self, + ) -> Callable[ + [iam_policy_pb2.GetIamPolicyRequest], + Union[policy_pb2.Policy, Awaitable[policy_pb2.Policy]], + ]: + raise NotImplementedError() + + @property + def test_iam_permissions( + self, + ) -> Callable[ + [iam_policy_pb2.TestIamPermissionsRequest], + Union[ + iam_policy_pb2.TestIamPermissionsResponse, + Awaitable[iam_policy_pb2.TestIamPermissionsResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_location( + self, + ) -> Callable[ + [locations_pb2.GetLocationRequest], + Union[locations_pb2.Location, Awaitable[locations_pb2.Location]], + ]: + raise NotImplementedError() + + @property + def list_locations( + self, + ) -> Callable[ + [locations_pb2.ListLocationsRequest], + Union[ + locations_pb2.ListLocationsResponse, + Awaitable[locations_pb2.ListLocationsResponse], + ], + ]: + raise NotImplementedError() + + @property + def kind(self) -> str: + raise NotImplementedError() + + +__all__ = ("ModelGardenServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/grpc.py new file mode 100644 index 0000000000..6addd3d80d --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/grpc.py @@ -0,0 +1,476 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple, Union + +from google.api_core import grpc_helpers +from google.api_core import gapic_v1 +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.cloud.aiplatform_v1beta1.types import model_garden_service +from google.cloud.aiplatform_v1beta1.types import publisher_model +from google.cloud.location import locations_pb2 # type: ignore +from google.iam.v1 import iam_policy_pb2 # type: ignore +from google.iam.v1 import policy_pb2 # type: ignore +from google.longrunning import operations_pb2 +from .base import ModelGardenServiceTransport, DEFAULT_CLIENT_INFO + + +class ModelGardenServiceGrpcTransport(ModelGardenServiceTransport): + """gRPC backend transport for ModelGardenService. + + The interface of Model Garden Service. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. 
+ """ + + _stubs: Dict[str, Callable] + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[grpc.Channel] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + channel (Optional[grpc.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. 
A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. 
+ if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. 
If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service.""" + return self._grpc_channel + + @property + def get_publisher_model( + self, + ) -> Callable[ + [model_garden_service.GetPublisherModelRequest], publisher_model.PublisherModel + ]: + r"""Return a callable for the get publisher model method over gRPC. + + Gets a Model Garden publisher model. + + Returns: + Callable[[~.GetPublisherModelRequest], + ~.PublisherModel]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "get_publisher_model" not in self._stubs: + self._stubs["get_publisher_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelGardenService/GetPublisherModel", + request_serializer=model_garden_service.GetPublisherModelRequest.serialize, + response_deserializer=publisher_model.PublisherModel.deserialize, + ) + return self._stubs["get_publisher_model"] + + def close(self): + self.grpc_channel.close() + + @property + def delete_operation( + self, + ) -> Callable[[operations_pb2.DeleteOperationRequest], None]: + r"""Return a callable for the delete_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_operation" not in self._stubs: + self._stubs["delete_operation"] = self.grpc_channel.unary_unary( + "/google.longrunning.Operations/DeleteOperation", + request_serializer=operations_pb2.DeleteOperationRequest.SerializeToString, + response_deserializer=None, + ) + return self._stubs["delete_operation"] + + @property + def cancel_operation( + self, + ) -> Callable[[operations_pb2.CancelOperationRequest], None]: + r"""Return a callable for the cancel_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "cancel_operation" not in self._stubs: + self._stubs["cancel_operation"] = self.grpc_channel.unary_unary( + "/google.longrunning.Operations/CancelOperation", + request_serializer=operations_pb2.CancelOperationRequest.SerializeToString, + response_deserializer=None, + ) + return self._stubs["cancel_operation"] + + @property + def wait_operation( + self, + ) -> Callable[[operations_pb2.WaitOperationRequest], None]: + r"""Return a callable for the wait_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_operation" not in self._stubs: + self._stubs["wait_operation"] = self.grpc_channel.unary_unary( + "/google.longrunning.Operations/WaitOperation", + request_serializer=operations_pb2.WaitOperationRequest.SerializeToString, + response_deserializer=None, + ) + return self._stubs["wait_operation"] + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + r"""Return a callable for the get_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self.grpc_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self.grpc_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + @property + def list_locations( + self, + ) -> Callable[ + [locations_pb2.ListLocationsRequest], locations_pb2.ListLocationsResponse + ]: + r"""Return a callable for the list locations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "list_locations" not in self._stubs: + self._stubs["list_locations"] = self.grpc_channel.unary_unary( + "/google.cloud.location.Locations/ListLocations", + request_serializer=locations_pb2.ListLocationsRequest.SerializeToString, + response_deserializer=locations_pb2.ListLocationsResponse.FromString, + ) + return self._stubs["list_locations"] + + @property + def get_location( + self, + ) -> Callable[[locations_pb2.GetLocationRequest], locations_pb2.Location]: + r"""Return a callable for the list locations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_location" not in self._stubs: + self._stubs["get_location"] = self.grpc_channel.unary_unary( + "/google.cloud.location.Locations/GetLocation", + request_serializer=locations_pb2.GetLocationRequest.SerializeToString, + response_deserializer=locations_pb2.Location.FromString, + ) + return self._stubs["get_location"] + + @property + def set_iam_policy( + self, + ) -> Callable[[iam_policy_pb2.SetIamPolicyRequest], policy_pb2.Policy]: + r"""Return a callable for the set iam policy method over gRPC. + Sets the IAM access control policy on the specified + function. Replaces any existing policy. + Returns: + Callable[[~.SetIamPolicyRequest], + ~.Policy]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "set_iam_policy" not in self._stubs: + self._stubs["set_iam_policy"] = self.grpc_channel.unary_unary( + "/google.iam.v1.IAMPolicy/SetIamPolicy", + request_serializer=iam_policy_pb2.SetIamPolicyRequest.SerializeToString, + response_deserializer=policy_pb2.Policy.FromString, + ) + return self._stubs["set_iam_policy"] + + @property + def get_iam_policy( + self, + ) -> Callable[[iam_policy_pb2.GetIamPolicyRequest], policy_pb2.Policy]: + r"""Return a callable for the get iam policy method over gRPC. + Gets the IAM access control policy for a function. + Returns an empty policy if the function exists and does + not have a policy set. + Returns: + Callable[[~.GetIamPolicyRequest], + ~.Policy]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_iam_policy" not in self._stubs: + self._stubs["get_iam_policy"] = self.grpc_channel.unary_unary( + "/google.iam.v1.IAMPolicy/GetIamPolicy", + request_serializer=iam_policy_pb2.GetIamPolicyRequest.SerializeToString, + response_deserializer=policy_pb2.Policy.FromString, + ) + return self._stubs["get_iam_policy"] + + @property + def test_iam_permissions( + self, + ) -> Callable[ + [iam_policy_pb2.TestIamPermissionsRequest], + iam_policy_pb2.TestIamPermissionsResponse, + ]: + r"""Return a callable for the test iam permissions method over gRPC. + Tests the specified permissions against the IAM access control + policy for a function. If the function does not exist, this will + return an empty set of permissions, not a NOT_FOUND error. + Returns: + Callable[[~.TestIamPermissionsRequest], + ~.TestIamPermissionsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. 
+ # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "test_iam_permissions" not in self._stubs: + self._stubs["test_iam_permissions"] = self.grpc_channel.unary_unary( + "/google.iam.v1.IAMPolicy/TestIamPermissions", + request_serializer=iam_policy_pb2.TestIamPermissionsRequest.SerializeToString, + response_deserializer=iam_policy_pb2.TestIamPermissionsResponse.FromString, + ) + return self._stubs["test_iam_permissions"] + + @property + def kind(self) -> str: + return "grpc" + + +__all__ = ("ModelGardenServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..c9f783d407 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/grpc_asyncio.py @@ -0,0 +1,476 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union + +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers_async +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.aiplatform_v1beta1.types import model_garden_service +from google.cloud.aiplatform_v1beta1.types import publisher_model +from google.cloud.location import locations_pb2 # type: ignore +from google.iam.v1 import iam_policy_pb2 # type: ignore +from google.iam.v1 import policy_pb2 # type: ignore +from google.longrunning import operations_pb2 +from .base import ModelGardenServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import ModelGardenServiceGrpcTransport + + +class ModelGardenServiceGrpcAsyncIOTransport(ModelGardenServiceTransport): + """gRPC AsyncIO backend transport for ModelGardenService. + + The interface of Model Garden Service. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. 
These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[aio.Channel] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. 
These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. 
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. 
+ if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def get_publisher_model( + self, + ) -> Callable[ + [model_garden_service.GetPublisherModelRequest], + Awaitable[publisher_model.PublisherModel], + ]: + r"""Return a callable for the get publisher model method over gRPC. + + Gets a Model Garden publisher model. 
+ + Returns: + Callable[[~.GetPublisherModelRequest], + Awaitable[~.PublisherModel]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_publisher_model" not in self._stubs: + self._stubs["get_publisher_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelGardenService/GetPublisherModel", + request_serializer=model_garden_service.GetPublisherModelRequest.serialize, + response_deserializer=publisher_model.PublisherModel.deserialize, + ) + return self._stubs["get_publisher_model"] + + def close(self): + return self.grpc_channel.close() + + @property + def delete_operation( + self, + ) -> Callable[[operations_pb2.DeleteOperationRequest], None]: + r"""Return a callable for the delete_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_operation" not in self._stubs: + self._stubs["delete_operation"] = self.grpc_channel.unary_unary( + "/google.longrunning.Operations/DeleteOperation", + request_serializer=operations_pb2.DeleteOperationRequest.SerializeToString, + response_deserializer=None, + ) + return self._stubs["delete_operation"] + + @property + def cancel_operation( + self, + ) -> Callable[[operations_pb2.CancelOperationRequest], None]: + r"""Return a callable for the cancel_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+        if "cancel_operation" not in self._stubs:
+            self._stubs["cancel_operation"] = self.grpc_channel.unary_unary(
+                "/google.longrunning.Operations/CancelOperation",
+                request_serializer=operations_pb2.CancelOperationRequest.SerializeToString,
+                response_deserializer=None,
+            )
+        return self._stubs["cancel_operation"]
+
+    @property
+    def wait_operation(
+        self,
+    ) -> Callable[[operations_pb2.WaitOperationRequest], None]:
+        r"""Return a callable for the wait_operation method over gRPC."""
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "wait_operation" not in self._stubs:
+            self._stubs["wait_operation"] = self.grpc_channel.unary_unary(
+                "/google.longrunning.Operations/WaitOperation",
+                request_serializer=operations_pb2.WaitOperationRequest.SerializeToString,
+                response_deserializer=None,
+            )
+        return self._stubs["wait_operation"]
+
+    @property
+    def get_operation(
+        self,
+    ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]:
+        r"""Return a callable for the get_operation method over gRPC."""
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+ if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self.grpc_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self.grpc_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + @property + def list_locations( + self, + ) -> Callable[ + [locations_pb2.ListLocationsRequest], locations_pb2.ListLocationsResponse + ]: + r"""Return a callable for the list locations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+        if "list_locations" not in self._stubs:
+            self._stubs["list_locations"] = self.grpc_channel.unary_unary(
+                "/google.cloud.location.Locations/ListLocations",
+                request_serializer=locations_pb2.ListLocationsRequest.SerializeToString,
+                response_deserializer=locations_pb2.ListLocationsResponse.FromString,
+            )
+        return self._stubs["list_locations"]
+
+    @property
+    def get_location(
+        self,
+    ) -> Callable[[locations_pb2.GetLocationRequest], locations_pb2.Location]:
+        r"""Return a callable for the get location method over gRPC."""
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "get_location" not in self._stubs:
+            self._stubs["get_location"] = self.grpc_channel.unary_unary(
+                "/google.cloud.location.Locations/GetLocation",
+                request_serializer=locations_pb2.GetLocationRequest.SerializeToString,
+                response_deserializer=locations_pb2.Location.FromString,
+            )
+        return self._stubs["get_location"]
+
+    @property
+    def set_iam_policy(
+        self,
+    ) -> Callable[[iam_policy_pb2.SetIamPolicyRequest], policy_pb2.Policy]:
+        r"""Return a callable for the set iam policy method over gRPC.
+        Sets the IAM access control policy on the specified
+        function. Replaces any existing policy.
+        Returns:
+            Callable[[~.SetIamPolicyRequest],
+                    ~.Policy]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+ if "set_iam_policy" not in self._stubs: + self._stubs["set_iam_policy"] = self.grpc_channel.unary_unary( + "/google.iam.v1.IAMPolicy/SetIamPolicy", + request_serializer=iam_policy_pb2.SetIamPolicyRequest.SerializeToString, + response_deserializer=policy_pb2.Policy.FromString, + ) + return self._stubs["set_iam_policy"] + + @property + def get_iam_policy( + self, + ) -> Callable[[iam_policy_pb2.GetIamPolicyRequest], policy_pb2.Policy]: + r"""Return a callable for the get iam policy method over gRPC. + Gets the IAM access control policy for a function. + Returns an empty policy if the function exists and does + not have a policy set. + Returns: + Callable[[~.GetIamPolicyRequest], + ~.Policy]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_iam_policy" not in self._stubs: + self._stubs["get_iam_policy"] = self.grpc_channel.unary_unary( + "/google.iam.v1.IAMPolicy/GetIamPolicy", + request_serializer=iam_policy_pb2.GetIamPolicyRequest.SerializeToString, + response_deserializer=policy_pb2.Policy.FromString, + ) + return self._stubs["get_iam_policy"] + + @property + def test_iam_permissions( + self, + ) -> Callable[ + [iam_policy_pb2.TestIamPermissionsRequest], + iam_policy_pb2.TestIamPermissionsResponse, + ]: + r"""Return a callable for the test iam permissions method over gRPC. + Tests the specified permissions against the IAM access control + policy for a function. If the function does not exist, this will + return an empty set of permissions, not a NOT_FOUND error. + Returns: + Callable[[~.TestIamPermissionsRequest], + ~.TestIamPermissionsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. 
+ # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "test_iam_permissions" not in self._stubs: + self._stubs["test_iam_permissions"] = self.grpc_channel.unary_unary( + "/google.iam.v1.IAMPolicy/TestIamPermissions", + request_serializer=iam_policy_pb2.TestIamPermissionsRequest.SerializeToString, + response_deserializer=iam_policy_pb2.TestIamPermissionsResponse.FromString, + ) + return self._stubs["test_iam_permissions"] + + +__all__ = ("ModelGardenServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/schedule_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/schedule_service/async_client.py index b8482823ca..0a4f9964d8 100644 --- a/google/cloud/aiplatform_v1beta1/services/schedule_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/schedule_service/async_client.py @@ -806,6 +806,7 @@ async def resume_schedule( request: Optional[Union[schedule_service.ResumeScheduleRequest, dict]] = None, *, name: Optional[str] = None, + catch_up: Optional[bool] = None, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = (), @@ -856,6 +857,18 @@ async def sample_resume_schedule(): This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. + catch_up (:class:`bool`): + Optional. Whether to backfill missed runs when the + schedule is resumed from PAUSED state. If set to true, + all missed runs will be scheduled. New runs will be + scheduled after the backfill is complete. This will also + update + [Schedule.catch_up][google.cloud.aiplatform.v1beta1.Schedule.catch_up] + field. Default to false. + + This corresponds to the ``catch_up`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. 
retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -865,7 +878,7 @@ async def sample_resume_schedule(): # Create or coerce a protobuf request object. # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) + has_flattened_params = any([name, catch_up]) if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " @@ -878,6 +891,8 @@ async def sample_resume_schedule(): # request, apply these. if name is not None: request.name = name + if catch_up is not None: + request.catch_up = catch_up # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. diff --git a/google/cloud/aiplatform_v1beta1/services/schedule_service/client.py b/google/cloud/aiplatform_v1beta1/services/schedule_service/client.py index c59b2af013..3e2ad48b3f 100644 --- a/google/cloud/aiplatform_v1beta1/services/schedule_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/schedule_service/client.py @@ -1154,6 +1154,7 @@ def resume_schedule( request: Optional[Union[schedule_service.ResumeScheduleRequest, dict]] = None, *, name: Optional[str] = None, + catch_up: Optional[bool] = None, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = (), @@ -1204,6 +1205,18 @@ def sample_resume_schedule(): This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. + catch_up (bool): + Optional. Whether to backfill missed runs when the + schedule is resumed from PAUSED state. If set to true, + all missed runs will be scheduled. New runs will be + scheduled after the backfill is complete. 
This will also + update + [Schedule.catch_up][google.cloud.aiplatform.v1beta1.Schedule.catch_up] + field. Default to false. + + This corresponds to the ``catch_up`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -1213,7 +1226,7 @@ def sample_resume_schedule(): # Create or coerce a protobuf request object. # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) + has_flattened_params = any([name, catch_up]) if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " @@ -1230,6 +1243,8 @@ def sample_resume_schedule(): # request, apply these. if name is not None: request.name = name + if catch_up is not None: + request.catch_up = catch_up # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index 8f761fce23..6e0aefb985 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -118,6 +118,9 @@ GetEndpointRequest, ListEndpointsRequest, ListEndpointsResponse, + MutateDeployedModelOperationMetadata, + MutateDeployedModelRequest, + MutateDeployedModelResponse, UndeployModelOperationMetadata, UndeployModelRequest, UndeployModelResponse, @@ -429,6 +432,7 @@ SearchMigratableResourcesResponse, ) from .model import ( + LargeModelReference, Model, ModelContainerSpec, ModelSourceInfo, @@ -449,6 +453,10 @@ from .model_evaluation_slice import ( ModelEvaluationSlice, ) +from .model_garden_service import ( + GetPublisherModelRequest, + PublisherModelView, +) from .model_monitoring import ( ModelMonitoringAlertConfig, ModelMonitoringConfig, @@ -529,6 +537,9 @@ PredictResponse, RawPredictRequest, ) +from .publisher_model import ( + PublisherModel, +) from .saved_query import ( SavedQuery, ) @@ -753,6 +764,9 @@ "GetEndpointRequest", "ListEndpointsRequest", "ListEndpointsResponse", + "MutateDeployedModelOperationMetadata", + "MutateDeployedModelRequest", + "MutateDeployedModelResponse", "UndeployModelOperationMetadata", "UndeployModelRequest", "UndeployModelResponse", @@ -1003,6 +1017,7 @@ "MigrateResourceResponse", "SearchMigratableResourcesRequest", "SearchMigratableResourcesResponse", + "LargeModelReference", "Model", "ModelContainerSpec", "ModelSourceInfo", @@ -1016,6 +1031,8 @@ "ModelDeploymentMonitoringObjectiveType", "ModelEvaluation", "ModelEvaluationSlice", + "GetPublisherModelRequest", + "PublisherModelView", "ModelMonitoringAlertConfig", "ModelMonitoringConfig", "ModelMonitoringObjectiveConfig", @@ -1084,6 +1101,7 @@ "PredictRequest", "PredictResponse", "RawPredictRequest", + "PublisherModel", "SavedQuery", "Schedule", "CreateScheduleRequest", diff --git 
a/google/cloud/aiplatform_v1beta1/types/accelerator_type.py b/google/cloud/aiplatform_v1beta1/types/accelerator_type.py index 72b8f96d6b..38453222e6 100644 --- a/google/cloud/aiplatform_v1beta1/types/accelerator_type.py +++ b/google/cloud/aiplatform_v1beta1/types/accelerator_type.py @@ -48,7 +48,9 @@ class AcceleratorType(proto.Enum): NVIDIA_TESLA_A100 (8): Nvidia Tesla A100 GPU. NVIDIA_A100_80GB (9): - Nvidia A2 Ultra GPU. + Nvidia A100 80GB GPU. + NVIDIA_L4 (11): + Nvidia L4 GPU. TPU_V2 (6): TPU v2. TPU_V3 (7): @@ -64,6 +66,7 @@ class AcceleratorType(proto.Enum): NVIDIA_TESLA_T4 = 5 NVIDIA_TESLA_A100 = 8 NVIDIA_A100_80GB = 9 + NVIDIA_L4 = 11 TPU_V2 = 6 TPU_V3 = 7 TPU_V4_POD = 10 diff --git a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py index fbb96e5647..5d357138b3 100644 --- a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py +++ b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py @@ -242,10 +242,10 @@ class BatchPredictionJob(proto.Message): disable_container_logging (bool): For custom-trained Models and AutoML Tabular Models, the container of the DeployedModel instances will send - ``stderr`` and ``stdout`` streams to Stackdriver Logging by + ``stderr`` and ``stdout`` streams to Cloud Logging by default. Please note that the logs incur cost, which are subject to `Cloud Logging - pricing `__. + pricing `__. User can disable container logging by setting this flag to true. 
diff --git a/google/cloud/aiplatform_v1beta1/types/custom_job.py b/google/cloud/aiplatform_v1beta1/types/custom_job.py index d74757647d..ca2f379cdb 100644 --- a/google/cloud/aiplatform_v1beta1/types/custom_job.py +++ b/google/cloud/aiplatform_v1beta1/types/custom_job.py @@ -268,6 +268,13 @@ class CustomJobSpec(proto.Message): [Trial.web_access_uris][google.cloud.aiplatform.v1beta1.Trial.web_access_uris] (within [HyperparameterTuningJob.trials][google.cloud.aiplatform.v1beta1.HyperparameterTuningJob.trials]). + experiment (str): + Optional. The Experiment associated with this job. Format: + ``projects/{project}/locations/{location}/metadataStores/{metadataStores}/contexts/{experiment-name}`` + experiment_run (str): + Optional. The Experiment Run associated with this job. + Format: + ``projects/{project}/locations/{location}/metadataStores/{metadataStores}/contexts/{experiment-name}-{experiment-run-name}`` """ worker_pool_specs: MutableSequence["WorkerPoolSpec"] = proto.RepeatedField( @@ -309,6 +316,14 @@ class CustomJobSpec(proto.Message): proto.BOOL, number=16, ) + experiment: str = proto.Field( + proto.STRING, + number=17, + ) + experiment_run: str = proto.Field( + proto.STRING, + number=18, + ) class WorkerPoolSpec(proto.Message): diff --git a/google/cloud/aiplatform_v1beta1/types/dataset.py b/google/cloud/aiplatform_v1beta1/types/dataset.py index 7daa228cd3..bcc68d0f38 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset.py @@ -91,8 +91,7 @@ class Dataset(proto.Message): title. saved_queries (MutableSequence[google.cloud.aiplatform_v1beta1.types.SavedQuery]): All SavedQueries belong to the Dataset will be returned in - List/Get Dataset response. The - [annotation_specs][SavedQuery.annotation_specs] field will + List/Get Dataset response. 
The annotation_specs field will not be populated except for UI cases which will only use [annotation_spec_count][google.cloud.aiplatform.v1beta1.SavedQuery.annotation_spec_count]. In CreateDataset request, a SavedQuery is created together @@ -266,10 +265,9 @@ class ExportDataConfig(proto.Message): This field is a member of `oneof`_ ``split``. annotations_filter (str): - A filter on Annotations of the Dataset. Only Annotations on - to-be-exported DataItems(specified by [data_items_filter][]) - that match this filter will be exported. The filter syntax - is the same as in + An expression for filtering what part of the Dataset is to + be exported. Only Annotations that match this filter will be + exported. The filter syntax is the same as in [ListAnnotations][google.cloud.aiplatform.v1beta1.DatasetService.ListAnnotations]. """ diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint.py b/google/cloud/aiplatform_v1beta1/types/endpoint.py index 41a0da28ce..8270ad8b09 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint.py @@ -124,7 +124,8 @@ class Endpoint(proto.Message): model_deployment_monitoring_job (str): Output only. Resource name of the Model Monitoring job associated with this Endpoint if monitoring is enabled by - [CreateModelDeploymentMonitoringJob][]. Format: + [JobService.CreateModelDeploymentMonitoringJob][google.cloud.aiplatform.v1beta1.JobService.CreateModelDeploymentMonitoringJob]. + Format: ``projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}`` predict_request_response_logging_config (google.cloud.aiplatform_v1beta1.types.PredictRequestResponseLoggingConfig): Configures the request-response logging for @@ -291,22 +292,20 @@ class DeployedModel(proto.Message): account. enable_container_logging (bool): If true, the container of the DeployedModel instances will - send ``stderr`` and ``stdout`` streams to Stackdriver - Logging. 
+ send ``stderr`` and ``stdout`` streams to Cloud Logging. Only supported for custom-trained Models and AutoML Tabular Models. enable_access_logging (bool): If true, online prediction access logs are - sent to StackDriver Logging. + sent to Cloud Logging. These logs are like standard server access logs, containing information like timestamp and latency for each prediction request. - Note that Stackdriver logs may incur a cost, - especially if your project receives prediction - requests at a high queries per second rate - (QPS). Estimate your costs before enabling this - option. + Note that logs may incur a cost, especially if + your project receives prediction requests at a + high queries per second rate (QPS). Estimate + your costs before enabling this option. private_endpoints (google.cloud.aiplatform_v1beta1.types.PrivateEndpoints): Output only. Provide paths for users to send predict/explain/health requests directly to the deployed diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint_service.py b/google/cloud/aiplatform_v1beta1/types/endpoint_service.py index bca77343c7..f38c53764a 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint_service.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint_service.py @@ -40,6 +40,9 @@ "UndeployModelRequest", "UndeployModelResponse", "UndeployModelOperationMetadata", + "MutateDeployedModelRequest", + "MutateDeployedModelResponse", + "MutateDeployedModelOperationMetadata", }, ) @@ -403,4 +406,81 @@ class UndeployModelOperationMetadata(proto.Message): ) +class MutateDeployedModelRequest(proto.Message): + r"""Request message for + [EndpointService.MutateDeployedModel][google.cloud.aiplatform.v1beta1.EndpointService.MutateDeployedModel]. + + Attributes: + endpoint (str): + Required. The name of the Endpoint resource into which to + mutate a DeployedModel. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + deployed_model (google.cloud.aiplatform_v1beta1.types.DeployedModel): + Required. 
The DeployedModel to be mutated within the + Endpoint. Only the following fields can be mutated: + + - ``min_replica_count`` in either + [DedicatedResources][google.cloud.aiplatform.v1beta1.DedicatedResources] + or + [AutomaticResources][google.cloud.aiplatform.v1beta1.AutomaticResources] + - ``max_replica_count`` in either + [DedicatedResources][google.cloud.aiplatform.v1beta1.DedicatedResources] + or + [AutomaticResources][google.cloud.aiplatform.v1beta1.AutomaticResources] + - [autoscaling_metric_specs][google.cloud.aiplatform.v1beta1.DedicatedResources.autoscaling_metric_specs] + - ``disable_container_logging`` (v1 only) + - ``enable_container_logging`` (v1beta1 only) + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The update mask applies to the resource. See + [google.protobuf.FieldMask][google.protobuf.FieldMask]. + """ + + endpoint: str = proto.Field( + proto.STRING, + number=1, + ) + deployed_model: gca_endpoint.DeployedModel = proto.Field( + proto.MESSAGE, + number=2, + message=gca_endpoint.DeployedModel, + ) + update_mask: field_mask_pb2.FieldMask = proto.Field( + proto.MESSAGE, + number=4, + message=field_mask_pb2.FieldMask, + ) + + +class MutateDeployedModelResponse(proto.Message): + r"""Response message for + [EndpointService.MutateDeployedModel][google.cloud.aiplatform.v1beta1.EndpointService.MutateDeployedModel]. + + Attributes: + deployed_model (google.cloud.aiplatform_v1beta1.types.DeployedModel): + The DeployedModel that's being mutated. + """ + + deployed_model: gca_endpoint.DeployedModel = proto.Field( + proto.MESSAGE, + number=1, + message=gca_endpoint.DeployedModel, + ) + + +class MutateDeployedModelOperationMetadata(proto.Message): + r"""Runtime operation information for + [EndpointService.MutateDeployedModel][google.cloud.aiplatform.v1beta1.EndpointService.MutateDeployedModel]. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + The operation generic information. 
+ """ + + generic_metadata: operation.GenericOperationMetadata = proto.Field( + proto.MESSAGE, + number=1, + message=operation.GenericOperationMetadata, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/evaluated_annotation.py b/google/cloud/aiplatform_v1beta1/types/evaluated_annotation.py index 958668a070..5961e27165 100644 --- a/google/cloud/aiplatform_v1beta1/types/evaluated_annotation.py +++ b/google/cloud/aiplatform_v1beta1/types/evaluated_annotation.py @@ -89,10 +89,6 @@ class EvaluatedAnnotation(proto.Message): ancestor ModelEvaluation. The EvaluatedDataItemView consists of all ground truths and predictions on [data_item_payload][google.cloud.aiplatform.v1beta1.EvaluatedAnnotation.data_item_payload]. - - Can be passed in - [GetEvaluatedDataItemView's][ModelService.GetEvaluatedDataItemView][] - [id][GetEvaluatedDataItemViewRequest.id]. explanations (MutableSequence[google.cloud.aiplatform_v1beta1.types.EvaluatedAnnotationExplanation]): Explanations of [predictions][google.cloud.aiplatform.v1beta1.EvaluatedAnnotation.predictions]. diff --git a/google/cloud/aiplatform_v1beta1/types/explanation.py b/google/cloud/aiplatform_v1beta1/types/explanation.py index cc683f8c80..d920785fa2 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation.py @@ -705,22 +705,27 @@ class Examples(proto.Message): Attributes: nearest_neighbor_search_config (google.protobuf.struct_pb2.Value): - The configuration for the generated index, the semantics are - the same as + The full configuration for the generated index, the + semantics are the same as [metadata][google.cloud.aiplatform.v1beta1.Index.metadata] - and should match NearestNeighborSearchConfig. + and should match + `NearestNeighborSearchConfig `__. This field is a member of `oneof`_ ``config``. 
presets (google.cloud.aiplatform_v1beta1.types.Presets): - Preset config based on the desired query - speed-precision trade-off and modality + Simplified preset configuration, which + automatically sets configuration values based on + the desired query speed-precision trade-off and + modality. This field is a member of `oneof`_ ``config``. gcs_source (google.cloud.aiplatform_v1beta1.types.GcsSource): - The Cloud Storage location for the input - instances. + The Cloud Storage locations that contain the + instances to be indexed for approximate nearest + neighbor search. neighbor_count (int): - The number of neighbors to return. + The number of neighbors to return when + querying for examples. """ nearest_neighbor_search_config: struct_pb2.Value = proto.Field( @@ -753,13 +758,18 @@ class Presets(proto.Message): Attributes: query (google.cloud.aiplatform_v1beta1.types.Presets.Query): - Preset option controlling parameters for - query speed-precision trade-off + Preset option controlling parameters for speed-precision + trade-off when querying for examples. If omitted, defaults + to ``PRECISE``. This field is a member of `oneof`_ ``_query``. modality (google.cloud.aiplatform_v1beta1.types.Presets.Modality): - Preset option controlling parameters for - different modalities + The modality of the uploaded model, which + automatically configures the distance + measurement and feature normalization for the + underlying example index and queries. If your + model does not precisely fit one of these types, + it is okay to choose the closest type. """ class Query(proto.Enum): @@ -769,8 +779,7 @@ class Query(proto.Enum): Values: PRECISE (0): More precise neighbors as a trade-off against - slower response. This is also the default value - (field-number 0). + slower response. FAST (1): Faster response as a trade-off against less precise neighbors. 
@@ -819,10 +828,9 @@ class ExplanationSpecOverride(proto.Message): Attributes: parameters (google.cloud.aiplatform_v1beta1.types.ExplanationParameters): - The parameters to be overridden. Note that the - [method][google.cloud.aiplatform.v1beta1.ExplanationParameters.method] - cannot be changed. If not specified, no parameter is - overridden. + The parameters to be overridden. Note that + the attribution method cannot be changed. If not + specified, no parameter is overridden. metadata (google.cloud.aiplatform_v1beta1.types.ExplanationMetadataOverride): The metadata to be overridden. If not specified, no metadata is overridden. diff --git a/google/cloud/aiplatform_v1beta1/types/feature.py b/google/cloud/aiplatform_v1beta1/types/feature.py index 735e9e2168..af0cf11183 100644 --- a/google/cloud/aiplatform_v1beta1/types/feature.py +++ b/google/cloud/aiplatform_v1beta1/types/feature.py @@ -100,8 +100,8 @@ class Feature(proto.Message): If set to true, all types of data monitoring are disabled despite the config on EntityType. monitoring_stats (MutableSequence[google.cloud.aiplatform_v1beta1.types.FeatureStatsAnomaly]): - Output only. A list of historical [Snapshot - Analysis][FeaturestoreMonitoringConfig.SnapshotAnalysis] + Output only. A list of historical + [SnapshotAnalysis][google.cloud.aiplatform.v1beta1.FeaturestoreMonitoringConfig.SnapshotAnalysis] stats requested by user, sorted by [FeatureStatsAnomaly.start_time][google.cloud.aiplatform.v1beta1.FeatureStatsAnomaly.start_time] descending. 
@@ -147,11 +147,11 @@ class ValueType(proto.Enum): BYTES = 13 class MonitoringStatsAnomaly(proto.Message): - r"""A list of historical [Snapshot - Analysis][FeaturestoreMonitoringConfig.SnapshotAnalysis] or [Import - Feature Analysis] - [FeaturestoreMonitoringConfig.ImportFeatureAnalysis] stats requested - by user, sorted by + r"""A list of historical + [SnapshotAnalysis][google.cloud.aiplatform.v1beta1.FeaturestoreMonitoringConfig.SnapshotAnalysis] + or + [ImportFeaturesAnalysis][google.cloud.aiplatform.v1beta1.FeaturestoreMonitoringConfig.ImportFeaturesAnalysis] + stats requested by user, sorted by [FeatureStatsAnomaly.start_time][google.cloud.aiplatform.v1beta1.FeatureStatsAnomaly.start_time] descending. diff --git a/google/cloud/aiplatform_v1beta1/types/featurestore_online_service.py b/google/cloud/aiplatform_v1beta1/types/featurestore_online_service.py index f9a9a77395..ed3531742b 100644 --- a/google/cloud/aiplatform_v1beta1/types/featurestore_online_service.py +++ b/google/cloud/aiplatform_v1beta1/types/featurestore_online_service.py @@ -177,7 +177,7 @@ class Header(proto.Message): ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entityType}``. feature_descriptors (MutableSequence[google.cloud.aiplatform_v1beta1.types.ReadFeatureValuesResponse.FeatureDescriptor]): List of Feature metadata corresponding to each piece of - [ReadFeatureValuesResponse.data][]. + [ReadFeatureValuesResponse.EntityView.data][google.cloud.aiplatform.v1beta1.ReadFeatureValuesResponse.EntityView.data]. 
""" entity_type: str = proto.Field( diff --git a/google/cloud/aiplatform_v1beta1/types/featurestore_service.py b/google/cloud/aiplatform_v1beta1/types/featurestore_service.py index f351a0d137..47c61527bf 100644 --- a/google/cloud/aiplatform_v1beta1/types/featurestore_service.py +++ b/google/cloud/aiplatform_v1beta1/types/featurestore_service.py @@ -270,7 +270,7 @@ class UpdateFeaturestoreRequest(proto.Message): - ``labels`` - ``online_serving_config.fixed_node_count`` - ``online_serving_config.scaling`` - - ``online_storage_ttl_days`` (available in Preview) + - ``online_storage_ttl_days`` """ featurestore: gca_featurestore.Featurestore = proto.Field( @@ -1076,7 +1076,7 @@ class UpdateEntityTypeRequest(proto.Message): - ``monitoring_config.import_features_analysis.anomaly_detection_baseline`` - ``monitoring_config.numerical_threshold_config.value`` - ``monitoring_config.categorical_threshold_config.value`` - - ``offline_storage_ttl_days`` (available in Preview) + - ``offline_storage_ttl_days`` """ entity_type: gca_entity_type.EntityType = proto.Field( diff --git a/google/cloud/aiplatform_v1beta1/types/index_endpoint.py b/google/cloud/aiplatform_v1beta1/types/index_endpoint.py index aac23e7e5d..6ced26a6f1 100644 --- a/google/cloud/aiplatform_v1beta1/types/index_endpoint.py +++ b/google/cloud/aiplatform_v1beta1/types/index_endpoint.py @@ -259,15 +259,15 @@ class DeployedIndex(proto.Message): efficiency. enable_access_logging (bool): Optional. If true, private endpoint's access - logs are sent to StackDriver Logging. + logs are sent to Cloud Logging. These logs are like standard server access logs, containing information like timestamp and latency for each MatchRequest. - Note that Stackdriver logs may incur a cost, - especially if the deployed index receives a high - queries per second rate (QPS). Estimate your - costs before enabling this option. + Note that logs may incur a cost, especially if + the deployed index receives a high queries per + second rate (QPS). 
Estimate your costs before + enabling this option. deployed_index_auth_config (google.cloud.aiplatform_v1beta1.types.DeployedIndexAuthConfig): Optional. If set, the authentication is enabled for the private endpoint. diff --git a/google/cloud/aiplatform_v1beta1/types/match_service.py b/google/cloud/aiplatform_v1beta1/types/match_service.py index 255e2edf6c..3e215dfd16 100644 --- a/google/cloud/aiplatform_v1beta1/types/match_service.py +++ b/google/cloud/aiplatform_v1beta1/types/match_service.py @@ -42,8 +42,8 @@ class FindNeighborsRequest(proto.Message): Required. The name of the index endpoint. Format: ``projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}`` deployed_index_id (str): - The ID of the DeploydIndex that will serve the request. This - request is sent to a specific IndexEndpoint, as per the + The ID of the DeployedIndex that will serve the request. + This request is sent to a specific IndexEndpoint, as per the IndexEndpoint.network. That IndexEndpoint also has IndexEndpoint.deployed_indexes, and each such index has a DeployedIndex.id field. The value of the field below must @@ -208,7 +208,7 @@ class ReadIndexDatapointsRequest(proto.Message): Required. The name of the index endpoint. Format: ``projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}`` deployed_index_id (str): - The ID of the DeploydIndex that will serve + The ID of the DeployedIndex that will serve the request. ids (MutableSequence[str]): IDs of the datapoints to be searched for. 
diff --git a/google/cloud/aiplatform_v1beta1/types/model.py b/google/cloud/aiplatform_v1beta1/types/model.py index 272280ec15..ae12be9331 100644 --- a/google/cloud/aiplatform_v1beta1/types/model.py +++ b/google/cloud/aiplatform_v1beta1/types/model.py @@ -31,6 +31,7 @@ package="google.cloud.aiplatform.v1beta1", manifest={ "Model", + "LargeModelReference", "PredictSchemata", "ModelContainerSpec", "Port", @@ -540,6 +541,24 @@ class OriginalModelInfo(proto.Message): ) +class LargeModelReference(proto.Message): + r"""Contains information about the Large Model. + + Attributes: + name (str): + Required. The unique name of the large + Foundation or pre-built model. Like + "chat-panda", "text-panda". Or model name with + version ID, like "chat-panda-001", + "text-panda-005", etc. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + class PredictSchemata(proto.Message): r"""Contains the schemata used in Model's predictions and explanations via @@ -925,12 +944,15 @@ class ModelSourceType(proto.Enum): MODEL_GARDEN (4): The Model is saved or tuned from Model Garden. + GENIE (5): + The Model is saved or tuned from Genie. """ MODEL_SOURCE_TYPE_UNSPECIFIED = 0 AUTOML = 1 CUSTOM = 2 BQML = 3 MODEL_GARDEN = 4 + GENIE = 5 source_type: ModelSourceType = proto.Field( proto.ENUM, diff --git a/google/cloud/aiplatform_v1beta1/types/model_garden_service.py b/google/cloud/aiplatform_v1beta1/types/model_garden_service.py new file mode 100644 index 0000000000..71419bf136 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/model_garden_service.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1beta1", + manifest={ + "PublisherModelView", + "GetPublisherModelRequest", + }, +) + + +class PublisherModelView(proto.Enum): + r"""View enumeration of PublisherModel. + + Values: + PUBLISHER_MODEL_VIEW_UNSPECIFIED (0): + The default / unset value. The API will + default to the BASIC view. + PUBLISHER_MODEL_VIEW_BASIC (1): + Include basic metadata about the publisher + model, but not the full contents. + PUBLISHER_MODEL_VIEW_FULL (2): + Include everything. + PUBLISHER_MODEL_VERSION_VIEW_BASIC (3): + Include: VersionId, ModelVersionExternalName, + and SupportedActions. + """ + PUBLISHER_MODEL_VIEW_UNSPECIFIED = 0 + PUBLISHER_MODEL_VIEW_BASIC = 1 + PUBLISHER_MODEL_VIEW_FULL = 2 + PUBLISHER_MODEL_VERSION_VIEW_BASIC = 3 + + +class GetPublisherModelRequest(proto.Message): + r"""Request message for + [ModelGardenService.GetPublisherModel][google.cloud.aiplatform.v1beta1.ModelGardenService.GetPublisherModel] + + Attributes: + name (str): + Required. The name of the PublisherModel resource. Format: + ``publishers/{publisher}/models/{publisher_model}`` + language_code (str): + Optional. The IETF BCP-47 language code + representing the language in which the publisher + model's text information should be written in + (see go/bcp47). + view (google.cloud.aiplatform_v1beta1.types.PublisherModelView): + Optional. PublisherModel view specifying + which fields to read. 
+ """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + language_code: str = proto.Field( + proto.STRING, + number=2, + ) + view: "PublisherModelView" = proto.Field( + proto.ENUM, + number=3, + enum="PublisherModelView", + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/model_service.py b/google/cloud/aiplatform_v1beta1/types/model_service.py index 19d1dbf1a0..948653bf7d 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_service.py +++ b/google/cloud/aiplatform_v1beta1/types/model_service.py @@ -299,8 +299,10 @@ class ListModelVersionsRequest(proto.Message): The standard list page size. page_token (str): The standard list page token. Typically obtained via - [ListModelVersionsResponse.next_page_token][google.cloud.aiplatform.v1beta1.ListModelVersionsResponse.next_page_token] - of the previous [ModelService.ListModelversions][] call. + [next_page_token][google.cloud.aiplatform.v1beta1.ListModelVersionsResponse.next_page_token] + of the previous + [ListModelVersions][google.cloud.aiplatform.v1beta1.ModelService.ListModelVersions] + call. filter (str): An expression for filtering the results of the request. For field names both snake_case and camelCase are supported. diff --git a/google/cloud/aiplatform_v1beta1/types/pipeline_job.py b/google/cloud/aiplatform_v1beta1/types/pipeline_job.py index 8e63e29840..9b000f7962 100644 --- a/google/cloud/aiplatform_v1beta1/types/pipeline_job.py +++ b/google/cloud/aiplatform_v1beta1/types/pipeline_job.py @@ -374,7 +374,8 @@ class PipelineTaskDetail(proto.Message): task is at the root level. task_name (str): Output only. The user specified name of the task that is - defined in [PipelineJob.spec][]. + defined in + [pipeline_spec][google.cloud.aiplatform.v1beta1.PipelineJob.pipeline_spec]. create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Task create time. 
start_time (google.protobuf.timestamp_pb2.Timestamp): diff --git a/google/cloud/aiplatform_v1beta1/types/publisher_model.py b/google/cloud/aiplatform_v1beta1/types/publisher_model.py new file mode 100644 index 0000000000..f596150a77 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/publisher_model.py @@ -0,0 +1,395 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.cloud.aiplatform_v1beta1.types import machine_resources +from google.cloud.aiplatform_v1beta1.types import model + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1beta1", + manifest={ + "PublisherModel", + }, +) + + +class PublisherModel(proto.Message): + r"""A Model Garden Publisher Model. + + Attributes: + name (str): + Output only. The resource name of the + PublisherModel. + version_id (str): + Output only. Immutable. The version ID of the + PublisherModel. A new version is committed when + a new model version is uploaded under an + existing model id. It is an auto-incrementing + decimal number in string representation. + open_source_category (google.cloud.aiplatform_v1beta1.types.PublisherModel.OpenSourceCategory): + Required. Indicates the open source category + of the publisher model. + supported_actions (google.cloud.aiplatform_v1beta1.types.PublisherModel.CallToAction): + Optional. 
Supported call-to-action options. + frameworks (MutableSequence[str]): + Optional. Additional information about the + model's Frameworks. + publisher_model_template (str): + Optional. Output only. Immutable. Used to + indicate this model has a publisher model and + provide the template of the publisher model + resource name. + predict_schemata (google.cloud.aiplatform_v1beta1.types.PredictSchemata): + Optional. The schemata that describes formats of the + PublisherModel's predictions and explanations as given and + returned via + [PredictionService.Predict][google.cloud.aiplatform.v1beta1.PredictionService.Predict]. + """ + + class OpenSourceCategory(proto.Enum): + r"""An enum representing the open source category of a + PublisherModel. + + Values: + OPEN_SOURCE_CATEGORY_UNSPECIFIED (0): + The open source category is unspecified, + which should not be used. + PROPRIETARY (1): + Used to indicate the PublisherModel is not + open sourced. + GOOGLE_OWNED_OSS_WITH_GOOGLE_CHECKPOINT (2): + Used to indicate the PublisherModel is a + Google-owned open source model w/ Google + checkpoint. + THIRD_PARTY_OWNED_OSS_WITH_GOOGLE_CHECKPOINT (3): + Used to indicate the PublisherModel is a + 3p-owned open source model w/ Google checkpoint. + GOOGLE_OWNED_OSS (4): + Used to indicate the PublisherModel is a + Google-owned pure open source model. + THIRD_PARTY_OWNED_OSS (5): + Used to indicate the PublisherModel is a + 3p-owned pure open source model. + """ + OPEN_SOURCE_CATEGORY_UNSPECIFIED = 0 + PROPRIETARY = 1 + GOOGLE_OWNED_OSS_WITH_GOOGLE_CHECKPOINT = 2 + THIRD_PARTY_OWNED_OSS_WITH_GOOGLE_CHECKPOINT = 3 + GOOGLE_OWNED_OSS = 4 + THIRD_PARTY_OWNED_OSS = 5 + + class ResourceReference(proto.Message): + r"""Reference to a resource. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. 
_oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + uri (str): + The URI of the resource. + + This field is a member of `oneof`_ ``reference``. + resource_name (str): + The resource name of the GCP resource. + + This field is a member of `oneof`_ ``reference``. + """ + + uri: str = proto.Field( + proto.STRING, + number=1, + oneof="reference", + ) + resource_name: str = proto.Field( + proto.STRING, + number=2, + oneof="reference", + ) + + class Documentation(proto.Message): + r"""A named piece of documentation. + + Attributes: + title (str): + Required. E.g., OVERVIEW, USE CASES, + DOCUMENTATION, SDK & SAMPLES, JAVA, NODE.JS, + etc.. + content (str): + Required. Content of this piece of document + (in Markdown format). + """ + + title: str = proto.Field( + proto.STRING, + number=1, + ) + content: str = proto.Field( + proto.STRING, + number=2, + ) + + class CallToAction(proto.Message): + r"""Actions could take on this Publisher Model. + + Attributes: + view_rest_api (google.cloud.aiplatform_v1beta1.types.PublisherModel.CallToAction.ViewRestApi): + Optional. To view Rest API docs. + open_notebook (google.cloud.aiplatform_v1beta1.types.PublisherModel.CallToAction.RegionalResourceReferences): + Optional. Open notebook of the + PublisherModel. + create_application (google.cloud.aiplatform_v1beta1.types.PublisherModel.CallToAction.RegionalResourceReferences): + Optional. Create application using the + PublisherModel. + open_fine_tuning_pipeline (google.cloud.aiplatform_v1beta1.types.PublisherModel.CallToAction.RegionalResourceReferences): + Optional. Open fine-tuning pipeline of the + PublisherModel. + open_prompt_tuning_pipeline (google.cloud.aiplatform_v1beta1.types.PublisherModel.CallToAction.RegionalResourceReferences): + Optional. Open prompt-tuning pipeline of the + PublisherModel. 
+ open_genie (google.cloud.aiplatform_v1beta1.types.PublisherModel.CallToAction.RegionalResourceReferences): + Optional. Open Genie / Playground. + deploy (google.cloud.aiplatform_v1beta1.types.PublisherModel.CallToAction.Deploy): + Optional. Deploy the PublisherModel to Vertex + Endpoint. + open_generation_ai_studio (google.cloud.aiplatform_v1beta1.types.PublisherModel.CallToAction.RegionalResourceReferences): + Optional. Open in Generation AI Studio. + """ + + class RegionalResourceReferences(proto.Message): + r"""The regional resource name or the URI. Key is region, e.g., + us-central1, europe-west2, global, etc.. + + Attributes: + references (MutableMapping[str, google.cloud.aiplatform_v1beta1.types.PublisherModel.ResourceReference]): + Required. + title (str): + Required. The title of the regional resource + reference. + """ + + references: MutableMapping[ + str, "PublisherModel.ResourceReference" + ] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=1, + message="PublisherModel.ResourceReference", + ) + title: str = proto.Field( + proto.STRING, + number=2, + ) + + class ViewRestApi(proto.Message): + r"""Rest API docs. + + Attributes: + documentations (MutableSequence[google.cloud.aiplatform_v1beta1.types.PublisherModel.Documentation]): + Required. + title (str): + Required. The title of the view rest API. + """ + + documentations: MutableSequence[ + "PublisherModel.Documentation" + ] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="PublisherModel.Documentation", + ) + title: str = proto.Field( + proto.STRING, + number=2, + ) + + class Deploy(proto.Message): + r"""Model metadata that is needed for UploadModel or + DeployModel/CreateEndpoint requests. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. 
_oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + dedicated_resources (google.cloud.aiplatform_v1beta1.types.DedicatedResources): + A description of resources that are dedicated + to the DeployedModel, and that need a higher + degree of manual configuration. + + This field is a member of `oneof`_ ``prediction_resources``. + automatic_resources (google.cloud.aiplatform_v1beta1.types.AutomaticResources): + A description of resources that to large + degree are decided by Vertex AI, and require + only a modest additional configuration. + + This field is a member of `oneof`_ ``prediction_resources``. + shared_resources (str): + The resource name of the shared DeploymentResourcePool to + deploy on. Format: + ``projects/{project}/locations/{location}/deploymentResourcePools/{deployment_resource_pool}`` + + This field is a member of `oneof`_ ``prediction_resources``. + model_display_name (str): + Optional. Default model display name. + large_model_reference (google.cloud.aiplatform_v1beta1.types.LargeModelReference): + Optional. Large model reference. When this is set, + model_artifact_spec is not needed. + container_spec (google.cloud.aiplatform_v1beta1.types.ModelContainerSpec): + Optional. The specification of the container + that is to be used when deploying this Model in + Vertex AI. Not present for Large Models. + artifact_uri (str): + Optional. The path to the directory + containing the Model artifact and any of its + supporting files. + title (str): + Required. The title of the regional resource + reference. 
+ """ + + dedicated_resources: machine_resources.DedicatedResources = proto.Field( + proto.MESSAGE, + number=5, + oneof="prediction_resources", + message=machine_resources.DedicatedResources, + ) + automatic_resources: machine_resources.AutomaticResources = proto.Field( + proto.MESSAGE, + number=6, + oneof="prediction_resources", + message=machine_resources.AutomaticResources, + ) + shared_resources: str = proto.Field( + proto.STRING, + number=7, + oneof="prediction_resources", + ) + model_display_name: str = proto.Field( + proto.STRING, + number=1, + ) + large_model_reference: model.LargeModelReference = proto.Field( + proto.MESSAGE, + number=2, + message=model.LargeModelReference, + ) + container_spec: model.ModelContainerSpec = proto.Field( + proto.MESSAGE, + number=3, + message=model.ModelContainerSpec, + ) + artifact_uri: str = proto.Field( + proto.STRING, + number=4, + ) + title: str = proto.Field( + proto.STRING, + number=8, + ) + + view_rest_api: "PublisherModel.CallToAction.ViewRestApi" = proto.Field( + proto.MESSAGE, + number=1, + message="PublisherModel.CallToAction.ViewRestApi", + ) + open_notebook: "PublisherModel.CallToAction.RegionalResourceReferences" = ( + proto.Field( + proto.MESSAGE, + number=2, + message="PublisherModel.CallToAction.RegionalResourceReferences", + ) + ) + create_application: "PublisherModel.CallToAction.RegionalResourceReferences" = ( + proto.Field( + proto.MESSAGE, + number=3, + message="PublisherModel.CallToAction.RegionalResourceReferences", + ) + ) + open_fine_tuning_pipeline: "PublisherModel.CallToAction.RegionalResourceReferences" = proto.Field( + proto.MESSAGE, + number=4, + message="PublisherModel.CallToAction.RegionalResourceReferences", + ) + open_prompt_tuning_pipeline: "PublisherModel.CallToAction.RegionalResourceReferences" = proto.Field( + proto.MESSAGE, + number=5, + message="PublisherModel.CallToAction.RegionalResourceReferences", + ) + open_genie: "PublisherModel.CallToAction.RegionalResourceReferences" = ( + 
proto.Field( + proto.MESSAGE, + number=6, + message="PublisherModel.CallToAction.RegionalResourceReferences", + ) + ) + deploy: "PublisherModel.CallToAction.Deploy" = proto.Field( + proto.MESSAGE, + number=7, + message="PublisherModel.CallToAction.Deploy", + ) + open_generation_ai_studio: "PublisherModel.CallToAction.RegionalResourceReferences" = proto.Field( + proto.MESSAGE, + number=8, + message="PublisherModel.CallToAction.RegionalResourceReferences", + ) + + name: str = proto.Field( + proto.STRING, + number=1, + ) + version_id: str = proto.Field( + proto.STRING, + number=2, + ) + open_source_category: OpenSourceCategory = proto.Field( + proto.ENUM, + number=7, + enum=OpenSourceCategory, + ) + supported_actions: CallToAction = proto.Field( + proto.MESSAGE, + number=19, + message=CallToAction, + ) + frameworks: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=23, + ) + publisher_model_template: str = proto.Field( + proto.STRING, + number=30, + ) + predict_schemata: model.PredictSchemata = proto.Field( + proto.MESSAGE, + number=31, + message=model.PredictSchemata, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/schedule.py b/google/cloud/aiplatform_v1beta1/types/schedule.py index 36d013a29c..15b0a82f90 100644 --- a/google/cloud/aiplatform_v1beta1/types/schedule.py +++ b/google/cloud/aiplatform_v1beta1/types/schedule.py @@ -91,6 +91,9 @@ class Schedule(proto.Message): create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this Schedule was created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Schedule was + updated. next_run_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this Schedule should schedule the next run. Having a next_run_time in the past means the @@ -103,7 +106,10 @@ class Schedule(proto.Message): last resumed. Unset if never resumed from pause. 
max_concurrent_run_count (int): Required. Maximum number of runs that can be - executed concurrently for this Schedule. + started concurrently for this Schedule. This is + the limit for starting the scheduled requests + and not the execution of the operations/jobs + created by the requests (if applicable). allow_queueing (bool): Optional. Whether new scheduled runs can be queued when max_concurrent_runs limit is reached. If set to true, new @@ -114,6 +120,13 @@ class Schedule(proto.Message): If set to true, all missed runs will be scheduled. New runs will be scheduled after the backfill is complete. Default to false. + last_scheduled_run_response (google.cloud.aiplatform_v1beta1.types.Schedule.RunResponse): + Output only. Response of the last scheduled + run. This is the response for starting the + scheduled requests and not the execution of the + operations/jobs created by the requests (if + applicable). Unset if no run has been scheduled + yet. """ class State(proto.Enum): @@ -140,6 +153,27 @@ class State(proto.Enum): PAUSED = 2 COMPLETED = 3 + class RunResponse(proto.Message): + r"""Status of a scheduled run. + + Attributes: + scheduled_run_time (google.protobuf.timestamp_pb2.Timestamp): + The scheduled run time based on the + user-specified schedule. + run_response (str): + The response of the scheduled run. 
+ """ + + scheduled_run_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=1, + message=timestamp_pb2.Timestamp, + ) + run_response: str = proto.Field( + proto.STRING, + number=2, + ) + cron: str = proto.Field( proto.STRING, number=10, @@ -189,6 +223,11 @@ class State(proto.Enum): number=6, message=timestamp_pb2.Timestamp, ) + update_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=19, + message=timestamp_pb2.Timestamp, + ) next_run_time: timestamp_pb2.Timestamp = proto.Field( proto.MESSAGE, number=7, @@ -216,6 +255,11 @@ class State(proto.Enum): proto.BOOL, number=13, ) + last_scheduled_run_response: RunResponse = proto.Field( + proto.MESSAGE, + number=18, + message=RunResponse, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/tensorboard_experiment.py b/google/cloud/aiplatform_v1beta1/types/tensorboard_experiment.py index b845b79ca0..be4f09d7b5 100644 --- a/google/cloud/aiplatform_v1beta1/types/tensorboard_experiment.py +++ b/google/cloud/aiplatform_v1beta1/types/tensorboard_experiment.py @@ -54,7 +54,7 @@ class TensorboardExperiment(proto.Message): The labels with user-defined metadata to organize your Datasets. - Label keys and values can be no longer than 64 characters + Label keys and values cannot be longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. No more than 64 user labels can be @@ -62,13 +62,13 @@ class TensorboardExperiment(proto.Message): See https://goo.gl/xmQnxf for more information and examples of labels. System reserved label keys are prefixed with - "aiplatform.googleapis.com/" and are immutable. Following - system labels exist for each Dataset: + ``aiplatform.googleapis.com/`` and are immutable. 
The + following system labels exist for each Dataset: - - "aiplatform.googleapis.com/dataset_metadata_schema": - - - output only, its value is the - [metadata_schema's][metadata_schema_uri] title. + - ``aiplatform.googleapis.com/dataset_metadata_schema``: + output only. Its value is the + [metadata_schema's][google.cloud.aiplatform.v1beta1.Dataset.metadata_schema_uri] + title. etag (str): Used to perform consistent read-modify-write updates. If not set, a blind "overwrite" update diff --git a/google/cloud/aiplatform_v1beta1/types/tensorboard_service.py b/google/cloud/aiplatform_v1beta1/types/tensorboard_service.py index 3bf710c683..fa5eeee51f 100644 --- a/google/cloud/aiplatform_v1beta1/types/tensorboard_service.py +++ b/google/cloud/aiplatform_v1beta1/types/tensorboard_service.py @@ -1226,12 +1226,12 @@ class ExportTensorboardTimeSeriesDataRequest(proto.Message): 10000. Values above 10000 are coerced to 10000. page_token (str): A page token, received from a previous - [TensorboardService.ExportTensorboardTimeSeries][] call. - Provide this to retrieve the subsequent page. + [ExportTensorboardTimeSeriesData][google.cloud.aiplatform.v1beta1.TensorboardService.ExportTensorboardTimeSeriesData] + call. Provide this to retrieve the subsequent page. When paginating, all other parameters provided to - [TensorboardService.ExportTensorboardTimeSeries][] must - match the call that provided the page token. + [ExportTensorboardTimeSeriesData][google.cloud.aiplatform.v1beta1.TensorboardService.ExportTensorboardTimeSeriesData] + must match the call that provided the page token. order_by (str): Field to use to sort the TensorboardTimeSeries' data. By default, @@ -1270,9 +1270,9 @@ class ExportTensorboardTimeSeriesDataResponse(proto.Message): The returned time series data points. next_page_token (str): A token, which can be sent as - [ExportTensorboardTimeSeriesRequest.page_token][] to - retrieve the next page. If this field is omitted, there are - no subsequent pages. 
+ [page_token][google.cloud.aiplatform.v1beta1.ExportTensorboardTimeSeriesDataRequest.page_token] + to retrieve the next page. If this field is omitted, there + are no subsequent pages. """ @property diff --git a/samples/generated_samples/aiplatform_v1_generated_endpoint_service_mutate_deployed_model_async.py b/samples/generated_samples/aiplatform_v1_generated_endpoint_service_mutate_deployed_model_async.py new file mode 100644 index 0000000000..789d6a52ce --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_endpoint_service_mutate_deployed_model_async.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for MutateDeployedModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_EndpointService_MutateDeployedModel_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +async def sample_mutate_deployed_model(): + # Create a client + client = aiplatform_v1.EndpointServiceAsyncClient() + + # Initialize request argument(s) + deployed_model = aiplatform_v1.DeployedModel() + deployed_model.dedicated_resources.min_replica_count = 1803 + deployed_model.model = "model_value" + + request = aiplatform_v1.MutateDeployedModelRequest( + endpoint="endpoint_value", + deployed_model=deployed_model, + ) + + # Make the request + operation = client.mutate_deployed_model(request=request) + + print("Waiting for operation to complete...") + + response = (await operation).result() + + # Handle the response + print(response) + +# [END aiplatform_v1_generated_EndpointService_MutateDeployedModel_async] diff --git a/samples/generated_samples/aiplatform_v1_generated_endpoint_service_mutate_deployed_model_sync.py b/samples/generated_samples/aiplatform_v1_generated_endpoint_service_mutate_deployed_model_sync.py new file mode 100644 index 0000000000..6653910a7f --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_endpoint_service_mutate_deployed_model_sync.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! 
+# +# Snippet for MutateDeployedModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_EndpointService_MutateDeployedModel_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +def sample_mutate_deployed_model(): + # Create a client + client = aiplatform_v1.EndpointServiceClient() + + # Initialize request argument(s) + deployed_model = aiplatform_v1.DeployedModel() + deployed_model.dedicated_resources.min_replica_count = 1803 + deployed_model.model = "model_value" + + request = aiplatform_v1.MutateDeployedModelRequest( + endpoint="endpoint_value", + deployed_model=deployed_model, + ) + + # Make the request + operation = client.mutate_deployed_model(request=request) + + print("Waiting for operation to complete...") + + response = operation.result() + + # Handle the response + print(response) + +# [END aiplatform_v1_generated_EndpointService_MutateDeployedModel_sync] diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_endpoint_service_mutate_deployed_model_async.py b/samples/generated_samples/aiplatform_v1beta1_generated_endpoint_service_mutate_deployed_model_async.py new file mode 100644 index 0000000000..7ff9ff385b --- /dev/null +++ b/samples/generated_samples/aiplatform_v1beta1_generated_endpoint_service_mutate_deployed_model_async.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# 
+# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for MutateDeployedModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1beta1_generated_EndpointService_MutateDeployedModel_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1beta1 + + +async def sample_mutate_deployed_model(): + # Create a client + client = aiplatform_v1beta1.EndpointServiceAsyncClient() + + # Initialize request argument(s) + deployed_model = aiplatform_v1beta1.DeployedModel() + deployed_model.dedicated_resources.min_replica_count = 1803 + deployed_model.model = "model_value" + + request = aiplatform_v1beta1.MutateDeployedModelRequest( + endpoint="endpoint_value", + deployed_model=deployed_model, + ) + + # Make the request + operation = client.mutate_deployed_model(request=request) + + print("Waiting for operation to complete...") + + response = (await operation).result() + + # Handle the response + print(response) + +# [END aiplatform_v1beta1_generated_EndpointService_MutateDeployedModel_async] diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_endpoint_service_mutate_deployed_model_sync.py b/samples/generated_samples/aiplatform_v1beta1_generated_endpoint_service_mutate_deployed_model_sync.py new file mode 100644 index 0000000000..de5baff8c0 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1beta1_generated_endpoint_service_mutate_deployed_model_sync.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# Generated code. DO NOT EDIT! +# +# Snippet for MutateDeployedModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1beta1_generated_EndpointService_MutateDeployedModel_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1beta1 + + +def sample_mutate_deployed_model(): + # Create a client + client = aiplatform_v1beta1.EndpointServiceClient() + + # Initialize request argument(s) + deployed_model = aiplatform_v1beta1.DeployedModel() + deployed_model.dedicated_resources.min_replica_count = 1803 + deployed_model.model = "model_value" + + request = aiplatform_v1beta1.MutateDeployedModelRequest( + endpoint="endpoint_value", + deployed_model=deployed_model, + ) + + # Make the request + operation = client.mutate_deployed_model(request=request) + + print("Waiting for operation to complete...") + + response = operation.result() + + # Handle the response + print(response) + +# [END aiplatform_v1beta1_generated_EndpointService_MutateDeployedModel_sync] diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_model_garden_service_get_publisher_model_async.py b/samples/generated_samples/aiplatform_v1beta1_generated_model_garden_service_get_publisher_model_async.py new file mode 100644 index 0000000000..f0badaf688 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1beta1_generated_model_garden_service_get_publisher_model_async.py @@ 
-0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetPublisherModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1beta1_generated_ModelGardenService_GetPublisherModel_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1beta1 + + +async def sample_get_publisher_model(): + # Create a client + client = aiplatform_v1beta1.ModelGardenServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.GetPublisherModelRequest( + name="name_value", + ) + + # Make the request + response = await client.get_publisher_model(request=request) + + # Handle the response + print(response) + +# [END aiplatform_v1beta1_generated_ModelGardenService_GetPublisherModel_async] diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_model_garden_service_get_publisher_model_sync.py b/samples/generated_samples/aiplatform_v1beta1_generated_model_garden_service_get_publisher_model_sync.py new file mode 100644 index 0000000000..13163228ed --- /dev/null +++ b/samples/generated_samples/aiplatform_v1beta1_generated_model_garden_service_get_publisher_model_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetPublisherModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. 
+ +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1beta1_generated_ModelGardenService_GetPublisherModel_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1beta1 + + +def sample_get_publisher_model(): + # Create a client + client = aiplatform_v1beta1.ModelGardenServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.GetPublisherModelRequest( + name="name_value", + ) + + # Make the request + response = client.get_publisher_model(request=request) + + # Handle the response + print(response) + +# [END aiplatform_v1beta1_generated_ModelGardenService_GetPublisherModel_sync] diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json index 779f19b5fe..b1fc52433e 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.24.1" + "version": "1.25.0" }, "snippets": [ { @@ -2804,6 +2804,183 @@ ], "title": "aiplatform_v1_generated_endpoint_service_list_endpoints_sync.py" }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.cloud.aiplatform_v1.EndpointServiceAsyncClient", + "shortName": "EndpointServiceAsyncClient" + }, + "fullName": "google.cloud.aiplatform_v1.EndpointServiceAsyncClient.mutate_deployed_model", + "method": 
{ + "fullName": "google.cloud.aiplatform.v1.EndpointService.MutateDeployedModel", + "service": { + "fullName": "google.cloud.aiplatform.v1.EndpointService", + "shortName": "EndpointService" + }, + "shortName": "MutateDeployedModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.MutateDeployedModelRequest" + }, + { + "name": "endpoint", + "type": "str" + }, + { + "name": "deployed_model", + "type": "google.cloud.aiplatform_v1.types.DeployedModel" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.api_core.operation_async.AsyncOperation", + "shortName": "mutate_deployed_model" + }, + "description": "Sample for MutateDeployedModel", + "file": "aiplatform_v1_generated_endpoint_service_mutate_deployed_model_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_EndpointService_MutateDeployedModel_async", + "segments": [ + { + "end": 60, + "start": 27, + "type": "FULL" + }, + { + "end": 60, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 50, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 57, + "start": 51, + "type": "REQUEST_EXECUTION" + }, + { + "end": 61, + "start": 58, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_endpoint_service_mutate_deployed_model_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.cloud.aiplatform_v1.EndpointServiceClient", + "shortName": "EndpointServiceClient" + }, + "fullName": "google.cloud.aiplatform_v1.EndpointServiceClient.mutate_deployed_model", + "method": { + "fullName": 
"google.cloud.aiplatform.v1.EndpointService.MutateDeployedModel", + "service": { + "fullName": "google.cloud.aiplatform.v1.EndpointService", + "shortName": "EndpointService" + }, + "shortName": "MutateDeployedModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.MutateDeployedModelRequest" + }, + { + "name": "endpoint", + "type": "str" + }, + { + "name": "deployed_model", + "type": "google.cloud.aiplatform_v1.types.DeployedModel" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.api_core.operation.Operation", + "shortName": "mutate_deployed_model" + }, + "description": "Sample for MutateDeployedModel", + "file": "aiplatform_v1_generated_endpoint_service_mutate_deployed_model_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_EndpointService_MutateDeployedModel_sync", + "segments": [ + { + "end": 60, + "start": 27, + "type": "FULL" + }, + { + "end": 60, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 50, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 57, + "start": 51, + "type": "REQUEST_EXECUTION" + }, + { + "end": 61, + "start": 58, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_endpoint_service_mutate_deployed_model_sync.py" + }, { "canonical": true, "clientMethod": { diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json index c6fc29e8c8..617490aba3 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json +++ 
b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.24.1" + "version": "1.25.0" }, "snippets": [ { @@ -3625,6 +3625,183 @@ ], "title": "aiplatform_v1beta1_generated_endpoint_service_list_endpoints_sync.py" }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.cloud.aiplatform_v1beta1.EndpointServiceAsyncClient", + "shortName": "EndpointServiceAsyncClient" + }, + "fullName": "google.cloud.aiplatform_v1beta1.EndpointServiceAsyncClient.mutate_deployed_model", + "method": { + "fullName": "google.cloud.aiplatform.v1beta1.EndpointService.MutateDeployedModel", + "service": { + "fullName": "google.cloud.aiplatform.v1beta1.EndpointService", + "shortName": "EndpointService" + }, + "shortName": "MutateDeployedModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1beta1.types.MutateDeployedModelRequest" + }, + { + "name": "endpoint", + "type": "str" + }, + { + "name": "deployed_model", + "type": "google.cloud.aiplatform_v1beta1.types.DeployedModel" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.api_core.operation_async.AsyncOperation", + "shortName": "mutate_deployed_model" + }, + "description": "Sample for MutateDeployedModel", + "file": "aiplatform_v1beta1_generated_endpoint_service_mutate_deployed_model_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1beta1_generated_EndpointService_MutateDeployedModel_async", + "segments": [ + { + "end": 60, + "start": 27, + "type": "FULL" + }, + { + "end": 60, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": 
"CLIENT_INITIALIZATION" + }, + { + "end": 50, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 57, + "start": 51, + "type": "REQUEST_EXECUTION" + }, + { + "end": 61, + "start": 58, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1beta1_generated_endpoint_service_mutate_deployed_model_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.cloud.aiplatform_v1beta1.EndpointServiceClient", + "shortName": "EndpointServiceClient" + }, + "fullName": "google.cloud.aiplatform_v1beta1.EndpointServiceClient.mutate_deployed_model", + "method": { + "fullName": "google.cloud.aiplatform.v1beta1.EndpointService.MutateDeployedModel", + "service": { + "fullName": "google.cloud.aiplatform.v1beta1.EndpointService", + "shortName": "EndpointService" + }, + "shortName": "MutateDeployedModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1beta1.types.MutateDeployedModelRequest" + }, + { + "name": "endpoint", + "type": "str" + }, + { + "name": "deployed_model", + "type": "google.cloud.aiplatform_v1beta1.types.DeployedModel" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.api_core.operation.Operation", + "shortName": "mutate_deployed_model" + }, + "description": "Sample for MutateDeployedModel", + "file": "aiplatform_v1beta1_generated_endpoint_service_mutate_deployed_model_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1beta1_generated_EndpointService_MutateDeployedModel_sync", + "segments": [ + { + "end": 60, + "start": 27, + "type": "FULL" + }, + { + "end": 60, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 
50, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 57, + "start": 51, + "type": "REQUEST_EXECUTION" + }, + { + "end": 61, + "start": 58, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1beta1_generated_endpoint_service_mutate_deployed_model_sync.py" + }, { "canonical": true, "clientMethod": { @@ -21991,6 +22168,167 @@ ], "title": "aiplatform_v1beta1_generated_migration_service_search_migratable_resources_sync.py" }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.cloud.aiplatform_v1beta1.ModelGardenServiceAsyncClient", + "shortName": "ModelGardenServiceAsyncClient" + }, + "fullName": "google.cloud.aiplatform_v1beta1.ModelGardenServiceAsyncClient.get_publisher_model", + "method": { + "fullName": "google.cloud.aiplatform.v1beta1.ModelGardenService.GetPublisherModel", + "service": { + "fullName": "google.cloud.aiplatform.v1beta1.ModelGardenService", + "shortName": "ModelGardenService" + }, + "shortName": "GetPublisherModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1beta1.types.GetPublisherModelRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.cloud.aiplatform_v1beta1.types.PublisherModel", + "shortName": "get_publisher_model" + }, + "description": "Sample for GetPublisherModel", + "file": "aiplatform_v1beta1_generated_model_garden_service_get_publisher_model_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1beta1_generated_ModelGardenService_GetPublisherModel_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + 
"start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1beta1_generated_model_garden_service_get_publisher_model_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.cloud.aiplatform_v1beta1.ModelGardenServiceClient", + "shortName": "ModelGardenServiceClient" + }, + "fullName": "google.cloud.aiplatform_v1beta1.ModelGardenServiceClient.get_publisher_model", + "method": { + "fullName": "google.cloud.aiplatform.v1beta1.ModelGardenService.GetPublisherModel", + "service": { + "fullName": "google.cloud.aiplatform.v1beta1.ModelGardenService", + "shortName": "ModelGardenService" + }, + "shortName": "GetPublisherModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1beta1.types.GetPublisherModelRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.cloud.aiplatform_v1beta1.types.PublisherModel", + "shortName": "get_publisher_model" + }, + "description": "Sample for GetPublisherModel", + "file": "aiplatform_v1beta1_generated_model_garden_service_get_publisher_model_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1beta1_generated_ModelGardenService_GetPublisherModel_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": 
"aiplatform_v1beta1_generated_model_garden_service_get_publisher_model_sync.py" + }, { "canonical": true, "clientMethod": { @@ -27939,6 +28277,10 @@ "name": "name", "type": "str" }, + { + "name": "catch_up", + "type": "bool" + }, { "name": "retry", "type": "google.api_core.retry.Retry" @@ -28016,6 +28358,10 @@ "name": "name", "type": "str" }, + { + "name": "catch_up", + "type": "bool" + }, { "name": "retry", "type": "google.api_core.retry.Retry" diff --git a/samples/model-builder/conftest.py b/samples/model-builder/conftest.py index 0b2df17f18..d11ea9ea23 100644 --- a/samples/model-builder/conftest.py +++ b/samples/model-builder/conftest.py @@ -451,6 +451,12 @@ def mock_tensorboard(): yield mock +@pytest.fixture +def mock_TensorBoard_uploaderTracker(): + mock = MagicMock(aiplatform.uploader_tracker) + yield mock + + @pytest.fixture def mock_create_tensorboard(mock_tensorboard): with patch.object(aiplatform.Tensorboard, "create") as mock: @@ -458,6 +464,29 @@ def mock_create_tensorboard(mock_tensorboard): yield mock +@pytest.fixture +def mock_tensorboard_uploader_onetime(): + with patch.object(aiplatform, "upload_tb_log") as mock_tensorboard_uploader_onetime: + mock_tensorboard_uploader_onetime.return_value = None + yield mock_tensorboard_uploader_onetime + + +@pytest.fixture +def mock_tensorboard_uploader_start(): + with patch.object( + aiplatform, "start_upload_tb_log" + ) as mock_tensorboard_uploader_start: + mock_tensorboard_uploader_start.return_value = None + yield mock_tensorboard_uploader_start + + +@pytest.fixture +def mock_tensorboard_uploader_end(): + with patch.object(aiplatform, "end_upload_tb_log") as mock_tensorboard_uploader_end: + mock_tensorboard_uploader_end.return_value = None + yield mock_tensorboard_uploader_end + + """ ---------------------------------------------------------------------------- Endpoint Fixtures diff --git a/samples/model-builder/experiment_tracking/log_metrics_sample.py 
b/samples/model-builder/experiment_tracking/log_metrics_sample.py index 5bff351610..172ab59e04 100644 --- a/samples/model-builder/experiment_tracking/log_metrics_sample.py +++ b/samples/model-builder/experiment_tracking/log_metrics_sample.py @@ -27,7 +27,7 @@ def log_metrics_sample( ): aiplatform.init(experiment=experiment_name, project=project, location=location) - aiplatform.start_run(run=run_name, resume=True) + aiplatform.start_run(run=run_name) aiplatform.log_metrics(metrics) diff --git a/samples/model-builder/experiment_tracking/upload_tensorboard_log_continuously_sample.py b/samples/model-builder/experiment_tracking/upload_tensorboard_log_continuously_sample.py new file mode 100644 index 0000000000..47769d7027 --- /dev/null +++ b/samples/model-builder/experiment_tracking/upload_tensorboard_log_continuously_sample.py @@ -0,0 +1,46 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_upload_tensorboard_log_sample] +def upload_tensorboard_log_continuously_sample( + tensorboard_experiment_name: str, + logdir: str, + tensorboard_id: str, + project: str, + location: str, + experiment_display_name: Optional[str] = None, + run_name_prefix: Optional[str] = None, + description: Optional[str] = None, +) -> None: + + aiplatform.init(project=project, location=location) + + # Continuous monitoring + aiplatform.start_upload_tb_log( + tensorboard_id=tensorboard_id, + tensorboard_experiment_name=tensorboard_experiment_name, + logdir=logdir, + experiment_display_name=experiment_display_name, + run_name_prefix=run_name_prefix, + description=description, + ) + aiplatform.end_upload_tb_log() + + +# [END aiplatform_sdk_upload_tensorboard_log_sample] diff --git a/samples/model-builder/experiment_tracking/upload_tensorboard_log_continuously_sample_test.py b/samples/model-builder/experiment_tracking/upload_tensorboard_log_continuously_sample_test.py new file mode 100644 index 0000000000..b88f55b878 --- /dev/null +++ b/samples/model-builder/experiment_tracking/upload_tensorboard_log_continuously_sample_test.py @@ -0,0 +1,50 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from experiment_tracking import upload_tensorboard_log_continuously_sample +import test_constants as constants + + +def test_upload_tensorboard_log_continuously_sample( + mock_sdk_init, + mock_tensorboard_uploader_start, + mock_tensorboard_uploader_end, +): + upload_tensorboard_log_continuously_sample.upload_tensorboard_log_continuously_sample( + project=constants.PROJECT, + location=constants.LOCATION, + logdir=constants.TENSORBOARD_LOG_DIR, + tensorboard_id=constants.TENSORBOARD_ID, + tensorboard_experiment_name=constants.TENSORBOARD_EXPERIMENT_NAME, + experiment_display_name=constants.EXPERIMENT_NAME, + run_name_prefix=constants.EXPERIMENT_RUN_NAME, + description=constants.DESCRIPTION, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, + location=constants.LOCATION, + ) + + mock_tensorboard_uploader_start.assert_called_once_with( + logdir=constants.TENSORBOARD_LOG_DIR, + tensorboard_id=constants.TENSORBOARD_ID, + tensorboard_experiment_name=constants.TENSORBOARD_EXPERIMENT_NAME, + experiment_display_name=constants.EXPERIMENT_NAME, + run_name_prefix=constants.EXPERIMENT_RUN_NAME, + description=constants.DESCRIPTION, + ) + + mock_tensorboard_uploader_end.assert_called_once_with() diff --git a/samples/model-builder/experiment_tracking/upload_tensorboard_log_one_time_sample.py b/samples/model-builder/experiment_tracking/upload_tensorboard_log_one_time_sample.py new file mode 100644 index 0000000000..18e55e8967 --- /dev/null +++ b/samples/model-builder/experiment_tracking/upload_tensorboard_log_one_time_sample.py @@ -0,0 +1,46 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_upload_tensorboard_log_one_time_sample] +def upload_tensorboard_log_one_time_sample( + tensorboard_experiment_name: str, + logdir: str, + tensorboard_id: str, + project: str, + location: str, + experiment_display_name: Optional[str] = None, + run_name_prefix: Optional[str] = None, + description: Optional[str] = None, + verbosity: Optional[int] = 1, +) -> None: + + aiplatform.init(project=project, location=location) + + # one time upload + aiplatform.upload_tb_log( + tensorboard_id=tensorboard_id, + tensorboard_experiment_name=tensorboard_experiment_name, + logdir=logdir, + experiment_display_name=experiment_display_name, + run_name_prefix=run_name_prefix, + description=description, + ) + + +# [END aiplatform_sdk_upload_tensorboard_log_one_time_sample] diff --git a/samples/model-builder/experiment_tracking/upload_tensorboard_log_one_time_sample_test.py b/samples/model-builder/experiment_tracking/upload_tensorboard_log_one_time_sample_test.py new file mode 100644 index 0000000000..4eaff7d841 --- /dev/null +++ b/samples/model-builder/experiment_tracking/upload_tensorboard_log_one_time_sample_test.py @@ -0,0 +1,47 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from experiment_tracking import upload_tensorboard_log_one_time_sample +import test_constants as constants + + +def test_upload_tensorboard_log_one_time_sample( + mock_sdk_init, + mock_tensorboard_uploader_onetime, +): + upload_tensorboard_log_one_time_sample.upload_tensorboard_log_one_time_sample( + project=constants.PROJECT, + location=constants.LOCATION, + logdir=constants.TENSORBOARD_LOG_DIR, + tensorboard_id=constants.TENSORBOARD_ID, + tensorboard_experiment_name=constants.TENSORBOARD_EXPERIMENT_NAME, + experiment_display_name=constants.EXPERIMENT_NAME, + run_name_prefix=constants.EXPERIMENT_RUN_NAME, + description=constants.DESCRIPTION, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, + location=constants.LOCATION, + ) + + mock_tensorboard_uploader_onetime.assert_called_once_with( + logdir=constants.TENSORBOARD_LOG_DIR, + tensorboard_id=constants.TENSORBOARD_ID, + tensorboard_experiment_name=constants.TENSORBOARD_EXPERIMENT_NAME, + experiment_display_name=constants.EXPERIMENT_NAME, + run_name_prefix=constants.EXPERIMENT_RUN_NAME, + description=constants.DESCRIPTION, + ) diff --git a/samples/model-builder/test_constants.py b/samples/model-builder/test_constants.py index d84df4f752..c48d19a505 100644 --- a/samples/model-builder/test_constants.py +++ b/samples/model-builder/test_constants.py @@ -272,10 +272,6 @@ VALIDATION_OPTIONS = "fail-pipeline" PREDEFINED_SPLIT_COLUMN_NAME = "predefined" -TENSORBOARD_NAME = ( - f"projects/{PROJECT}/locations/{LOCATION}/tensorboards/my-tensorboard" -) - SCHEMA_TITLE = "system.Schema" 
SCHEMA_VERSION = "0.0.1" METADATA = {} @@ -332,3 +328,11 @@ IS_DEFAULT_VERSION = False VERSION_ALIASES = ["test-version-alias"] VERSION_DESCRIPTION = "test-version-description" + +# TensorBoard +TENSORBOARD_LOG_DIR = "gs://fake-dir" +TENSORBOARD_ID = "8888888888888888888" +TENSORBOARD_NAME = ( + f"projects/{PROJECT}/locations/{LOCATION}/tensorboards/my-tensorboard" +) +TENSORBOARD_EXPERIMENT_NAME = "my-tensorboard-experiment" diff --git a/setup.py b/setup.py index 487071c3b4..93432eec0c 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,7 @@ packages = [ package for package in setuptools.PEP420PackageFinder.find() - if package.startswith("google") + if package.startswith("google") or package.startswith("vertexai") ] tensorboard_extra_require = ["tensorflow >=2.3.0, <3.0.0dev"] diff --git a/tests/system/aiplatform/test_autologging.py b/tests/system/aiplatform/test_autologging.py index bae63f8440..7c1bf49993 100644 --- a/tests/system/aiplatform/test_autologging.py +++ b/tests/system/aiplatform/test_autologging.py @@ -130,19 +130,12 @@ def setup_class(cls): cls._experiment_autocreate_tf = cls._make_display_name("")[:64] cls._experiment_manual_scikit = cls._make_display_name("")[:64] cls._experiment_nested_run = cls._make_display_name("")[:64] - cls._experiment_disable_test = cls._make_display_name("")[:64] cls._backing_tensorboard = aiplatform.Tensorboard.create( project=e2e_base._PROJECT, location=e2e_base._LOCATION, display_name=cls._make_display_name("")[:64], ) - cls._experiment_enable_name = cls._make_display_name("")[:64] - cls._experiment_enable_test = aiplatform.Experiment.get_or_create( - experiment_name=cls._experiment_enable_name - ) - cls._experiment_enable_test.assign_backing_tensorboard(cls._backing_tensorboard) - def test_autologging_with_autorun_creation(self, shared_state): aiplatform.init( @@ -279,34 +272,3 @@ def test_autologging_nested_run_model(self, shared_state, caplog): assert "This model creates nested runs." 
in caplog.text caplog.clear() - - def test_autologging_enable_disable_check(self, shared_state, caplog): - - caplog.set_level(logging.INFO) - - # first enable autologging with provided tb-backed experiment - aiplatform.init( - project=e2e_base._PROJECT, - location=e2e_base._LOCATION, - experiment=self._experiment_enable_name, - ) - - shared_state["resources"].append( - aiplatform.metadata.metadata._experiment_tracker.experiment - ) - - aiplatform.autolog() - - assert aiplatform.utils.autologging_utils._is_autologging_enabled() - - aiplatform.metadata.metadata._experiment_tracker._global_tensorboard = None - - # re-initializing without tb-backed experiment should disable autologging - aiplatform.init( - project=e2e_base._PROJECT, - location=e2e_base._LOCATION, - experiment=self._experiment_disable_test, - ) - - assert "Disabling" in caplog.text - caplog.clear() diff --git a/tests/system/aiplatform/test_e2e_forecasting.py b/tests/system/aiplatform/test_e2e_forecasting.py index 938d0e27b5..45b16a015f 100644 --- a/tests/system/aiplatform/test_e2e_forecasting.py +++ b/tests/system/aiplatform/test_e2e_forecasting.py @@ -18,7 +18,7 @@ from google.cloud import aiplatform from google.cloud.aiplatform import training_jobs -# from google.cloud.aiplatform.compat.types import job_state +from google.cloud.aiplatform.compat.types import job_state from google.cloud.aiplatform.compat.types import pipeline_state import pytest from tests.system.aiplatform import e2e_base @@ -103,24 +103,23 @@ def test_end_to_end_forecasting(self, shared_state, training_job): ) resources.append(model) - # TODO(b/275569167) Uncomment this when the bug is fixed - # batch_prediction_job = model.batch_predict( - # job_display_name=self._make_display_name("forecasting-liquor-model"), - # instances_format="bigquery", - # predictions_format="csv", - # machine_type="n1-standard-4", - # bigquery_source=_PREDICTION_DATASET_BQ_PATH, - # gcs_destination_prefix=( - # 
f'gs://{shared_state["staging_bucket_name"]}/bp_results/' - # ), - # sync=False, - # ) - # resources.append(batch_prediction_job) + batch_prediction_job = model.batch_predict( + job_display_name=self._make_display_name("forecasting-liquor-model"), + instances_format="bigquery", + predictions_format="csv", + machine_type="n1-standard-4", + bigquery_source=_PREDICTION_DATASET_BQ_PATH, + gcs_destination_prefix=( + f'gs://{shared_state["staging_bucket_name"]}/bp_results/' + ), + sync=False, + ) + resources.append(batch_prediction_job) - # batch_prediction_job.wait() + batch_prediction_job.wait() model.wait() assert job.state == pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED - # assert batch_prediction_job.state == job_state.JobState.JOB_STATE_SUCCEEDED + assert batch_prediction_job.state == job_state.JobState.JOB_STATE_SUCCEEDED finally: for resource in resources: resource.delete() diff --git a/tests/system/aiplatform/test_e2e_tabular.py b/tests/system/aiplatform/test_e2e_tabular.py index 6cdedad5a2..f6700fa099 100644 --- a/tests/system/aiplatform/test_e2e_tabular.py +++ b/tests/system/aiplatform/test_e2e_tabular.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2021 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -16,20 +16,21 @@ # import os -from urllib import request import pytest +from google.cloud import storage + from google.cloud import aiplatform from google.cloud.aiplatform.compat.types import ( - # job_state as gca_job_state, + job_state as gca_job_state, pipeline_state as gca_pipeline_state, ) from tests.system.aiplatform import e2e_base -_BLOB_PATH = "california-housing-data.csv" -_DATASET_SRC = "https://dl.google.com/mlcc/mledu-datasets/california_housing_train.csv" +_DATASET_TRAINING_SRC = "gs://cloud-samples-data-us-central1/vertex-ai/structured_data/california_housing/california-housing-data.csv" +_DATASET_BATCH_PREDICT_SRC = "gs://cloud-samples-data-us-central1/vertex-ai/batch-prediction/california_housing_batch_predict.jsonl" _DIR_NAME = os.path.dirname(os.path.abspath(__file__)) _LOCAL_TRAINING_SCRIPT_PATH = os.path.join( _DIR_NAME, "test_resources/california_housing_training_script.py" @@ -58,16 +59,6 @@ class TestEndToEndTabular(e2e_base.TestEndToEnd): def test_end_to_end_tabular(self, shared_state): """Build dataset, train a custom and AutoML model, deploy, and get predictions""" - assert shared_state["bucket"] - bucket = shared_state["bucket"] - - blob = bucket.blob(_BLOB_PATH) - - # Download the CSV file into memory and save it directory to staging bucket - with request.urlopen(_DATASET_SRC) as response: - data = response.read() - blob.upload_from_string(data) - # Collection of resources generated by this test, to be deleted during teardown shared_state["resources"] = [] @@ -79,11 +70,9 @@ def test_end_to_end_tabular(self, shared_state): # Create and import to single managed dataset for both training jobs - dataset_gcs_source = f'gs://{shared_state["staging_bucket_name"]}/{_BLOB_PATH}' - ds = aiplatform.TabularDataset.create( display_name=self._make_display_name("dataset"), - gcs_source=[dataset_gcs_source], + gcs_source=[_DATASET_TRAINING_SRC], sync=False, create_request_timeout=180.0, ) @@ -135,17 +124,16 @@ def test_end_to_end_tabular(self, 
shared_state): automl_endpoint = automl_model.deploy(machine_type="n1-standard-4", sync=False) shared_state["resources"].extend([automl_endpoint, custom_endpoint]) - # TODO(b/275569167) Uncomment this after timeout issue is resolved - # custom_batch_prediction_job = custom_model.batch_predict( - # job_display_name=self._make_display_name("automl-housing-model"), - # instances_format="csv", - # machine_type="n1-standard-4", - # gcs_source=dataset_gcs_source, - # gcs_destination_prefix=f'gs://{shared_state["staging_bucket_name"]}/bp_results/', - # sync=False, - # ) + custom_batch_prediction_job = custom_model.batch_predict( + job_display_name=self._make_display_name("custom-housing-model"), + instances_format="jsonl", + machine_type="n1-standard-4", + gcs_source=_DATASET_BATCH_PREDICT_SRC, + gcs_destination_prefix=f'gs://{shared_state["staging_bucket_name"]}/bp_results/', + sync=False, + ) - # shared_state["resources"].append(custom_batch_prediction_job) + shared_state["resources"].append(custom_batch_prediction_job) in_progress_done_check = custom_job.done() custom_job.wait_for_resource_creation() @@ -171,7 +159,7 @@ def test_end_to_end_tabular(self, shared_state): custom_prediction = custom_endpoint.predict([_INSTANCE], timeout=180.0) - # custom_batch_prediction_job.wait() + custom_batch_prediction_job.wait() automl_endpoint.wait() automl_prediction = automl_endpoint.predict( @@ -194,10 +182,25 @@ def test_end_to_end_tabular(self, shared_state): automl_job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED ) - # assert ( - # custom_batch_prediction_job.state - # == gca_job_state.JobState.JOB_STATE_SUCCEEDED - # ) + assert ( + custom_batch_prediction_job.state + == gca_job_state.JobState.JOB_STATE_SUCCEEDED + ) + + # Ensure batch prediction errors output file is empty + batch_predict_gcs_output_path = ( + custom_batch_prediction_job.output_info.gcs_output_directory + ) + client = storage.Client() + + for blob in client.list_blobs( + 
bucket_or_name=shared_state["staging_bucket_name"], + prefix=f"bp_results/{batch_predict_gcs_output_path.split('/')[-1]}", + ): + # There are always 2 files in this output path: 1 with errors, 1 with predictions + if "errors" in blob.name: + error_output_filestr = blob.download_as_string().decode() + assert not error_output_filestr # Ensure a single prediction was returned assert len(custom_prediction.predictions) == 1 diff --git a/tests/system/aiplatform/test_featurestore.py b/tests/system/aiplatform/test_featurestore.py index 7a05f3c15d..cf875bac85 100644 --- a/tests/system/aiplatform/test_featurestore.py +++ b/tests/system/aiplatform/test_featurestore.py @@ -36,6 +36,7 @@ _TEST_FEATURESTORE_ID = "movie_prediction" _TEST_USER_ENTITY_TYPE_ID = "users" _TEST_MOVIE_ENTITY_TYPE_ID = "movies" +_TEST_MOVIE_ENTITY_TYPE_UPDATE_LABELS = {"my_key_update": "my_value_update"} _TEST_USER_AGE_FEATURE_ID = "age" _TEST_USER_GENDER_FEATURE_ID = "gender" @@ -129,6 +130,15 @@ def test_create_get_list_entity_types(self, shared_state): entity_type.resource_name for entity_type in list_entity_types ] + # Update information about the movie entity type. + assert movie_entity_type.labels != _TEST_MOVIE_ENTITY_TYPE_UPDATE_LABELS + + movie_entity_type.update( + labels=_TEST_MOVIE_ENTITY_TYPE_UPDATE_LABELS, + ) + + assert movie_entity_type.labels == _TEST_MOVIE_ENTITY_TYPE_UPDATE_LABELS + def test_create_get_list_features(self, shared_state): assert shared_state["user_entity_type"] diff --git a/tests/system/aiplatform/test_language_models.py b/tests/system/aiplatform/test_language_models.py new file mode 100644 index 0000000000..f91ac2efad --- /dev/null +++ b/tests/system/aiplatform/test_language_models.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# pylint: disable=protected-access, g-multiple-import + +from google.cloud import aiplatform +from tests.system.aiplatform import e2e_base +from vertexai.preview.language_models import ( + ChatModel, + InputOutputTextPair, + TextGenerationModel, + TextEmbeddingModel, +) + + +class TestLanguageModels(e2e_base.TestEndToEnd): + """System tests for language models.""" + + _temp_prefix = "temp_language_models_test_" + + def test_text_generation(self): + aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) + + model = TextGenerationModel.from_pretrained("google/text-bison@001") + + assert model.predict( + "What is the best recipe for banana bread? Recipe:", + max_output_tokens=128, + temperature=0, + top_p=1, + top_k=5, + ).text + + def test_chat_on_chat_model(self): + aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) + + chat_model = ChatModel.from_pretrained("google/chat-bison@001") + chat = chat_model.start_chat( + context="My name is Ned. You are my personal assistant. 
My favorite movies are Lord of the Rings and Hobbit.", + examples=[ + InputOutputTextPair( + input_text="Who do you work for?", + output_text="I work for Ned.", + ), + InputOutputTextPair( + input_text="What do I like?", + output_text="Ned likes watching movies.", + ), + ], + temperature=0.0, + ) + + assert chat.send_message("Are my favorite movies based on a book series?").text + assert len(chat._history) == 1 + assert chat.send_message( + "When where these books published?", + temperature=0.1, + ).text + assert len(chat._history) == 2 + + def test_text_embedding(self): + aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) + + model = TextEmbeddingModel.from_pretrained("google/textembedding-gecko@001") + embeddings = model.get_embeddings(["What is life?"]) + assert embeddings + for embedding in embeddings: + vector = embedding.values + assert len(vector) == 768 diff --git a/tests/unit/aiplatform/test_autologging.py b/tests/unit/aiplatform/test_autologging.py index 4705c95169..a0516a61e8 100644 --- a/tests/unit/aiplatform/test_autologging.py +++ b/tests/unit/aiplatform/test_autologging.py @@ -60,6 +60,8 @@ from google.cloud.aiplatform.compat.types import ( tensorboard as gca_tensorboard, ) +from google.cloud.aiplatform.metadata import metadata + import test_tensorboard import test_metadata @@ -454,6 +456,15 @@ def update_context_mock(): yield update_context_mock +@pytest.fixture +def get_or_create_default_tb_none_mock(): + with patch.object( + metadata, "_get_or_create_default_tensorboard" + ) as get_or_create_default_tb_none_mock: + get_or_create_default_tb_none_mock.return_value = None + yield get_or_create_default_tb_none_mock + + _TEST_EXPERIMENT_RUN_CONTEXT_NAME = f"{_TEST_PARENT}/contexts/{_TEST_EXECUTION_ID}" _TEST_OTHER_EXPERIMENT_RUN_CONTEXT_NAME = ( f"{_TEST_PARENT}/contexts/{_TEST_OTHER_EXECUTION_ID}" @@ -702,6 +713,7 @@ def test_autologging_raises_if_experiment_not_set( "get_experiment_mock_without_tensorboard", 
"get_metadata_store_mock", "update_context_mock", + "get_or_create_default_tb_none_mock", ) def test_autologging_raises_if_experiment_tensorboard_not_set( self, diff --git a/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py b/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py index b605996704..19e3dfde41 100644 --- a/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py +++ b/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py @@ -1294,3 +1294,41 @@ def test_run_call_pipeline_if_set_additional_experiments_probabilistic_inference training_pipeline=true_training_pipeline, timeout=None, ) + + def test_automl_forecasting_with_no_transformations( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_time_series, + mock_model_service_get, + ): + aiplatform.init(project=_TEST_PROJECT) + job = training_jobs.AutoMLForecastingTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + ) + mock_dataset_time_series.column_names = [ + "a", + "b", + _TEST_TRAINING_TARGET_COLUMN, + ] + job.run( + dataset=mock_dataset_time_series, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + target_column=_TEST_TRAINING_TARGET_COLUMN, + time_column=_TEST_TRAINING_TIME_COLUMN, + time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN, + unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS, + available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS, + forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON, + data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT, + data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS, + context_window=_TEST_TRAINING_CONTEXT_WINDOW, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + ) + assert 
job._column_transformations == [ + {"auto": {"column_name": "a"}}, + {"auto": {"column_name": "b"}}, + ] diff --git a/tests/unit/aiplatform/test_featurestores.py b/tests/unit/aiplatform/test_featurestores.py index 25a939c795..3eccdb401f 100644 --- a/tests/unit/aiplatform/test_featurestores.py +++ b/tests/unit/aiplatform/test_featurestores.py @@ -591,8 +591,10 @@ def update_entity_type_mock(): with patch.object( featurestore_service_client.FeaturestoreServiceClient, "update_entity_type" ) as update_entity_type_mock: - update_entity_type_lro_mock = mock.Mock(operation.Operation) - update_entity_type_mock.return_value = update_entity_type_lro_mock + update_entity_type_mock.return_value = gca_entity_type.EntityType( + name=_TEST_ENTITY_TYPE_NAME, + labels=_TEST_LABELS_UPDATE, + ) yield update_entity_type_mock @@ -2104,6 +2106,8 @@ def test_update_entity_type(self, update_entity_type_mock): timeout=None, ) + assert my_entity_type.labels == _TEST_LABELS_UPDATE + @pytest.mark.parametrize( "featurestore_name", [_TEST_FEATURESTORE_NAME, _TEST_FEATURESTORE_ID] ) diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py new file mode 100644 index 0000000000..71b27465fc --- /dev/null +++ b/tests/unit/aiplatform/test_language_models.py @@ -0,0 +1,289 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# pylint: disable=protected-access,bad-continuation + +import pytest + +from importlib import reload +from unittest import mock + +from google.cloud import aiplatform +from google.cloud.aiplatform import base +from google.cloud.aiplatform import initializer + +from google.cloud.aiplatform.compat.services import ( + model_garden_service_client_v1beta1, +) +from google.cloud.aiplatform.compat.services import prediction_service_client +from google.cloud.aiplatform.compat.types import ( + prediction_service as gca_prediction_service, +) +from google.cloud.aiplatform_v1beta1.types import ( + publisher_model as gca_publisher_model, +) + +from vertexai.preview import language_models + + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" + +_TEXT_BISON_PUBLISHER_MODEL_DICT = { + "name": "publishers/google/models/text-bison", + "version_id": "001", + "open_source_category": "PROPRIETARY", + "publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/text-bison@001", + "predict_schemata": { + "instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml", + "parameters_schema_uri": "gs://google-cloud-aiplatfrom/schema/predict/params/text_generation_1.0.0.yaml", + "prediction_schema_uri": "gs://google-cloud-aiplatform/schema/predict/prediction/text_generation_1.0.0.yaml", + }, +} + +_CHAT_BISON_PUBLISHER_MODEL_DICT = { + "name": "publishers/google/models/chat-bison", + "version_id": "001", + "open_source_category": "PROPRIETARY", + "publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/chat-bison@001", + "predict_schemata": { + "instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml", + "parameters_schema_uri": "gs://google-cloud-aiplatfrom/schema/predict/params/chat_generation_1.0.0.yaml", + "prediction_schema_uri": 
"gs://google-cloud-aiplatform/schema/predict/prediction/chat_generation_1.0.0.yaml", + }, +} + +_TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT = { + "name": "publishers/google/models/textembedding-gecko", + "version_id": "001", + "open_source_category": "PROPRIETARY", + "publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/chat-bison@001", + "predict_schemata": { + "instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/text_embedding_1.0.0.yaml", + "parameters_schema_uri": "gs://google-cloud-aiplatfrom/schema/predict/params/text_generation_1.0.0.yaml", + "prediction_schema_uri": "gs://google-cloud-aiplatform/schema/predict/prediction/text_embedding_1.0.0.yaml", + }, +} + +_TEST_TEXT_GENERATION_PREDICTION = { + "safetyAttributes": { + "categories": ["Violent"], + "blocked": False, + "scores": [0.10000000149011612], + }, + "content": """ +Ingredients: +* 3 cups all-purpose flour + +Instructions: +1. Preheat oven to 350 degrees F (175 degrees C).""", +} + +_TEST_CHAT_GENERATION_PREDICTION1 = { + "safetyAttributes": { + "scores": [], + "blocked": False, + "categories": [], + }, + "candidates": [ + { + "author": "1", + "content": "Chat response 1", + } + ], +} +_TEST_CHAT_GENERATION_PREDICTION2 = { + "safetyAttributes": { + "scores": [], + "blocked": False, + "categories": [], + }, + "candidates": [ + { + "author": "1", + "content": "Chat response 2", + } + ], +} + +_TEXT_EMBEDDING_VECTOR_LENGTH = 768 +_TEST_TEXT_EMBEDDING_PREDICTION = { + "embeddings": { + "values": list([1.0] * _TEXT_EMBEDDING_VECTOR_LENGTH), + } +} + + +@pytest.mark.usefixtures("google_auth_mock") +class TestLanguageModels: + """Unit tests for the language models.""" + + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_text_generation(self): + """Tests the text generation model.""" + aiplatform.init( + project=_TEST_PROJECT, + 
location=_TEST_LOCATION, + ) + with mock.patch.object( + target=model_garden_service_client_v1beta1.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _TEXT_BISON_PUBLISHER_MODEL_DICT + ), + ) as mock_get_publisher_model: + model = language_models.TextGenerationModel.from_pretrained( + "google/text-bison@001" + ) + + mock_get_publisher_model.assert_called_once_with( + name="publishers/google/models/text-bison@001", retry=base._DEFAULT_RETRY + ) + + gca_predict_response = gca_prediction_service.PredictResponse() + gca_predict_response.predictions.append(_TEST_TEXT_GENERATION_PREDICTION) + + with mock.patch.object( + target=prediction_service_client.PredictionServiceClient, + attribute="predict", + return_value=gca_predict_response, + ): + response = model.predict( + "What is the best recipe for banana bread? Recipe:", + max_output_tokens=128, + temperature=0, + top_p=1, + top_k=5, + ) + + assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"] + + def test_chat(self): + """Tests the chat generation model.""" + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + with mock.patch.object( + target=model_garden_service_client_v1beta1.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _CHAT_BISON_PUBLISHER_MODEL_DICT + ), + ) as mock_get_publisher_model: + model = language_models.ChatModel.from_pretrained("google/chat-bison@001") + + mock_get_publisher_model.assert_called_once_with( + name="publishers/google/models/chat-bison@001", retry=base._DEFAULT_RETRY + ) + + chat = model.start_chat( + context=""" + My name is Ned. + You are my personal assistant. + My favorite movies are Lord of the Rings and Hobbit. 
+ """, + examples=[ + language_models.InputOutputTextPair( + input_text="Who do you work for?", + output_text="I work for Ned.", + ), + language_models.InputOutputTextPair( + input_text="What do I like?", + output_text="Ned likes watching movies.", + ), + ], + temperature=0.0, + ) + + gca_predict_response1 = gca_prediction_service.PredictResponse() + gca_predict_response1.predictions.append(_TEST_CHAT_GENERATION_PREDICTION1) + + with mock.patch.object( + target=prediction_service_client.PredictionServiceClient, + attribute="predict", + return_value=gca_predict_response1, + ): + response = chat.send_message( + "Are my favorite movies based on a book series?" + ) + assert ( + response.text + == _TEST_CHAT_GENERATION_PREDICTION1["candidates"][0]["content"] + ) + assert len(chat._history) == 1 + + gca_predict_response2 = gca_prediction_service.PredictResponse() + gca_predict_response2.predictions.append(_TEST_CHAT_GENERATION_PREDICTION2) + + with mock.patch.object( + target=prediction_service_client.PredictionServiceClient, + attribute="predict", + return_value=gca_predict_response2, + ): + response = chat.send_message( + "When where these books published?", + temperature=0.1, + ) + assert ( + response.text + == _TEST_CHAT_GENERATION_PREDICTION2["candidates"][0]["content"] + ) + assert len(chat._history) == 2 + + def test_text_embedding(self): + """Tests the text embedding model.""" + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + with mock.patch.object( + target=model_garden_service_client_v1beta1.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT + ), + ) as mock_get_publisher_model: + model = language_models.TextEmbeddingModel.from_pretrained( + "google/textembedding-gecko@001" + ) + + mock_get_publisher_model.assert_called_once_with( + name="publishers/google/models/textembedding-gecko@001", + retry=base._DEFAULT_RETRY, + ) + + 
gca_predict_response = gca_prediction_service.PredictResponse() + gca_predict_response.predictions.append(_TEST_TEXT_EMBEDDING_PREDICTION) + + with mock.patch.object( + target=prediction_service_client.PredictionServiceClient, + attribute="predict", + return_value=gca_predict_response, + ): + embeddings = model.get_embeddings(["What is life?"]) + assert embeddings + for embedding in embeddings: + vector = embedding.values + assert len(vector) == _TEXT_EMBEDDING_VECTOR_LENGTH + assert vector == _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]["values"] diff --git a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py index 9d176b96b5..3590c3e5e1 100644 --- a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py +++ b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py @@ -32,11 +32,13 @@ matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref, index_endpoint as gca_index_endpoint, index as gca_index, + match_service_v1beta1 as gca_match_service_v1beta1, + index_v1beta1 as gca_index_v1beta1, ) - from google.cloud.aiplatform.compat.services import ( index_endpoint_service_client, index_service_client, + match_service_client_v1beta1, ) import constants as test_constants @@ -229,6 +231,7 @@ _TEST_FILTER = [ Namespace(name="class", allow_tokens=["token_1"], deny_tokens=["token_2"]) ] +_TEST_IDS = ["123", "456", "789"] def uuid_mock(): @@ -476,6 +479,49 @@ def index_endpoint_match_queries_mock(): yield index_endpoint_match_queries_mock +@pytest.fixture +def index_public_endpoint_match_queries_mock(): + with patch.object( + match_service_client_v1beta1.MatchServiceClient, "find_neighbors" + ) as index_public_endpoint_match_queries_mock: + index_public_endpoint_match_queries_mock.return_value = ( + gca_match_service_v1beta1.FindNeighborsResponse( + nearest_neighbors=[ + gca_match_service_v1beta1.FindNeighborsResponse.NearestNeighbors( + id="1", + neighbors=[ + 
gca_match_service_v1beta1.FindNeighborsResponse.Neighbor( + datapoint=gca_index_v1beta1.IndexDatapoint( + datapoint_id="1" + ), + distance=0.1, + ) + ], + ) + ] + ) + ) + yield index_public_endpoint_match_queries_mock + + +@pytest.fixture +def index_public_endpoint_read_index_datapoints_mock(): + with patch.object( + match_service_client_v1beta1.MatchServiceClient, "read_index_datapoints" + ) as index_public_endpoint_read_index_datapoints_mock: + index_public_endpoint_read_index_datapoints_mock.return_value = ( + gca_match_service_v1beta1.ReadIndexDatapointsResponse( + datapoints=[ + gca_index_v1beta1.IndexDatapoint( + datapoint_id="1", + feature_vector=[0, 1, 2, 3], + ) + ] + ) + ) + yield index_public_endpoint_read_index_datapoints_mock + + @pytest.mark.usefixtures("google_auth_mock") class TestMatchingEngineIndexEndpoint: def setup_method(self): @@ -845,3 +891,71 @@ def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock): ) index_endpoint_match_queries_mock.assert_called_with(batch_request) + + @pytest.mark.usefixtures("get_index_public_endpoint_mock") + def test_index_public_endpoint_match_queries( + self, index_public_endpoint_match_queries_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + my_public_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + index_endpoint_name=_TEST_INDEX_ENDPOINT_ID + ) + + my_public_index_endpoint.find_neighbors( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + queries=_TEST_QUERIES, + num_neighbors=_TEST_NUM_NEIGHBOURS, + filter=_TEST_FILTER, + ) + + find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest( + index_endpoint=my_public_index_endpoint.resource_name, + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + queries=[ + gca_match_service_v1beta1.FindNeighborsRequest.Query( + neighbor_count=_TEST_NUM_NEIGHBOURS, + datapoint=gca_index_v1beta1.IndexDatapoint( + feature_vector=_TEST_QUERIES[0], + restricts=[ + gca_index_v1beta1.IndexDatapoint.Restriction( + namespace="class", + 
allow_list=["token_1"], + deny_list=["token_2"], + ) + ], + ), + ) + ], + ) + + index_public_endpoint_match_queries_mock.assert_called_with( + find_neighbors_request + ) + + @pytest.mark.usefixtures("get_index_public_endpoint_mock") + def test_index_public_endpoint_read_index_datapoints( + self, index_public_endpoint_read_index_datapoints_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + my_public_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + index_endpoint_name=_TEST_INDEX_ENDPOINT_ID + ) + + my_public_index_endpoint.read_index_datapoints( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + ids=_TEST_IDS, + ) + + read_index_datapoints_request = ( + gca_match_service_v1beta1.ReadIndexDatapointsRequest( + index_endpoint=my_public_index_endpoint.resource_name, + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + ids=_TEST_IDS, + ) + ) + + index_public_endpoint_read_index_datapoints_mock.assert_called_with( + read_index_datapoints_request + ) diff --git a/tests/unit/aiplatform/test_metadata.py b/tests/unit/aiplatform/test_metadata.py index a9ef13363a..545a4c016f 100644 --- a/tests/unit/aiplatform/test_metadata.py +++ b/tests/unit/aiplatform/test_metadata.py @@ -53,6 +53,9 @@ from google.cloud.aiplatform.compat.types import ( tensorboard_data as gca_tensorboard_data, ) +from google.cloud.aiplatform.compat.types import ( + tensorboard as gca_tensorboard, +) from google.cloud.aiplatform.compat.types import ( tensorboard_experiment as gca_tensorboard_experiment, ) @@ -63,10 +66,13 @@ tensorboard_time_series as gca_tensorboard_time_series, ) from google.cloud.aiplatform.metadata import constants +from google.cloud.aiplatform.metadata import experiment_resources from google.cloud.aiplatform.metadata import experiment_run_resource from google.cloud.aiplatform.metadata import metadata from google.cloud.aiplatform.metadata import metadata_store from google.cloud.aiplatform.metadata import utils as metadata_utils +from google.cloud.aiplatform.tensorboard import 
tensorboard_resource + from google.cloud.aiplatform import utils import constants as test_constants @@ -147,6 +153,13 @@ # schema _TEST_WRONG_SCHEMA_TITLE = "system.WrongSchema" +# tensorboard +_TEST_DEFAULT_TENSORBOARD_NAME = "test-tensorboard-default-name" +_TEST_DEFAULT_TENSORBOARD_GCA = gca_tensorboard.Tensorboard( + name=_TEST_DEFAULT_TENSORBOARD_NAME, + is_default=True, +) + @pytest.fixture def get_metadata_store_mock(): @@ -366,6 +379,53 @@ def get_tensorboard_run_not_found_mock(): yield get_tensorboard_run_mock +@pytest.fixture +def list_default_tensorboard_mock(): + with patch.object( + TensorboardServiceClient, "list_tensorboards" + ) as list_default_tensorboard_mock: + list_default_tensorboard_mock.side_effect = [ + [_TEST_DEFAULT_TENSORBOARD_GCA], + [_TEST_DEFAULT_TENSORBOARD_GCA], + ] + yield list_default_tensorboard_mock + + +@pytest.fixture +def list_default_tensorboard_empty_mock(): + with patch.object( + TensorboardServiceClient, "list_tensorboards" + ) as list_default_tensorboard_empty_mock: + list_default_tensorboard_empty_mock.return_value = [] + yield list_default_tensorboard_empty_mock + + +@pytest.fixture +def create_default_tensorboard_mock(): + with patch.object( + tensorboard_resource.Tensorboard, "create" + ) as create_default_tensorboard_mock: + create_default_tensorboard_mock.return_value = _TEST_DEFAULT_TENSORBOARD_GCA + yield create_default_tensorboard_mock + + +@pytest.fixture +def assign_backing_tensorboard_mock(): + with patch.object( + experiment_resources.Experiment, "assign_backing_tensorboard" + ) as assign_backing_tensorboard_mock: + yield assign_backing_tensorboard_mock + + +@pytest.fixture +def get_or_create_default_tb_none_mock(): + with patch.object( + metadata, "_get_or_create_default_tensorboard" + ) as get_or_create_default_tb_none_mock: + get_or_create_default_tb_none_mock.return_value = None + yield get_or_create_default_tb_none_mock + + @pytest.fixture def get_tensorboard_experiment_not_found_mock(): with 
patch.object( @@ -1025,6 +1085,7 @@ def setup_method(self): def teardown_method(self): initializer.global_pool.shutdown(wait=True) + @pytest.mark.usefixtures("get_or_create_default_tb_none_mock") def test_init_experiment_with_existing_metadataStore_and_context( self, get_metadata_store_mock, get_experiment_run_run_mock ): @@ -1041,6 +1102,27 @@ def test_init_experiment_with_existing_metadataStore_and_context( name=_TEST_CONTEXT_NAME, retry=base._DEFAULT_RETRY ) + @pytest.mark.usefixtures( + "get_metadata_store_mock", + "get_experiment_run_run_mock", + ) + def test_init_experiment_with_default_tensorboard( + self, list_default_tensorboard_mock, assign_backing_tensorboard_mock + ): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + experiment=_TEST_EXPERIMENT, + ) + + list_default_tensorboard_mock.assert_called_once_with( + request={ + "parent": f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}", + "filter": "is_default=true", + } + ) + assign_backing_tensorboard_mock.assert_called_once() + @pytest.mark.usefixtures("get_metadata_store_mock") def test_create_experiment(self, create_experiment_context_mock): exp = aiplatform.Experiment.create( @@ -1060,6 +1142,9 @@ def test_create_experiment(self, create_experiment_context_mock): assert exp._metadata_context.gca_resource == _TEST_EXPERIMENT_CONTEXT + @pytest.mark.usefixtures( + "get_or_create_default_tb_none_mock", + ) def test_init_experiment_with_credentials( self, get_metadata_store_mock, @@ -1116,6 +1201,7 @@ def test_init_and_get_then_create_metadata_store_with_credentials( assert store.api_client._transport._credentials == creds + @pytest.mark.usefixtures("get_or_create_default_tb_none_mock") def test_init_experiment_with_existing_description( self, get_metadata_store_mock, get_experiment_run_run_mock ): @@ -1133,7 +1219,11 @@ def test_init_experiment_with_existing_description( name=_TEST_CONTEXT_NAME, retry=base._DEFAULT_RETRY ) - @pytest.mark.usefixtures("get_metadata_store_mock", 
"get_experiment_run_run_mock") + @pytest.mark.usefixtures( + "get_metadata_store_mock", + "get_experiment_run_run_mock", + "get_or_create_default_tb_none_mock", + ) def test_init_experiment_without_existing_description( self, update_context_mock, @@ -1161,6 +1251,7 @@ def test_init_experiment_without_existing_description( "get_experiment_run_mock", "update_experiment_run_context_to_running", "get_tensorboard_run_artifact_not_found_mock", + "get_or_create_default_tb_none_mock", ) def test_init_experiment_reset(self): aiplatform.init( @@ -1244,6 +1335,7 @@ def test_start_run_from_env_experiment( "get_metadata_store_mock", "get_experiment_run_mock", "get_tensorboard_run_artifact_not_found_mock", + "get_or_create_default_tb_none_mock", ) def test_init_experiment_run_from_env(self): os.environ["AIP_EXPERIMENT_RUN_NAME"] = _TEST_RUN @@ -1319,7 +1411,9 @@ def test_get_experiment_run_not_found(self, get_experiment_run_not_found_mock): name=f"{_TEST_CONTEXT_NAME}-{_TEST_RUN}", retry=base._DEFAULT_RETRY ) - @pytest.mark.usefixtures("get_metadata_store_mock") + @pytest.mark.usefixtures( + "get_metadata_store_mock", "get_or_create_default_tb_none_mock" + ) def test_start_run( self, get_experiment_mock, @@ -1351,7 +1445,11 @@ def test_start_run( context=_EXPERIMENT_MOCK.name, child_contexts=[_EXPERIMENT_RUN_MOCK.name] ) - @pytest.mark.usefixtures("get_metadata_store_mock", "get_experiment_mock") + @pytest.mark.usefixtures( + "get_metadata_store_mock", + "get_experiment_mock", + "get_or_create_default_tb_none_mock", + ) def test_start_run_fails_when_run_name_too_long(self): aiplatform.init( @@ -1375,6 +1473,7 @@ def test_start_run_fails_when_run_name_too_long(self): "get_experiment_mock", "create_experiment_run_context_mock", "add_context_children_mock", + "get_or_create_default_tb_none_mock", ) def test_log_params( self, @@ -1398,6 +1497,7 @@ def test_log_params( "get_experiment_mock", "create_experiment_run_context_mock", "add_context_children_mock", + 
"get_or_create_default_tb_none_mock", ) def test_log_metrics(self, update_context_mock): aiplatform.init( @@ -1418,6 +1518,7 @@ def test_log_metrics(self, update_context_mock): "get_experiment_mock", "create_experiment_run_context_mock", "add_context_children_mock", + "get_or_create_default_tb_none_mock", ) def test_log_classification_metrics( self, @@ -1475,6 +1576,7 @@ def test_log_classification_metrics( "create_experiment_model_artifact_mock", "get_experiment_model_artifact_mock", "get_metadata_store_mock", + "get_or_create_default_tb_none_mock", ) def test_log_model( self, @@ -1615,6 +1717,7 @@ def test_log_time_series_metrics( "get_experiment_mock", "create_experiment_run_context_mock", "add_context_children_mock", + "get_or_create_default_tb_none_mock", ) def test_log_metrics_nest_value_raises_error(self): aiplatform.init( @@ -1629,6 +1732,7 @@ def test_log_metrics_nest_value_raises_error(self): "get_experiment_mock", "create_experiment_run_context_mock", "add_context_children_mock", + "get_or_create_default_tb_none_mock", ) def test_log_params_nest_value_raises_error(self): aiplatform.init( @@ -1644,6 +1748,7 @@ def test_log_params_nest_value_raises_error(self): "create_experiment_run_context_mock", "add_context_children_mock", "get_artifact_mock", + "get_or_create_default_tb_none_mock", ) def test_start_execution_and_assign_artifact( self, @@ -1723,6 +1828,7 @@ def test_start_execution_and_assign_artifact( "get_experiment_mock", "create_experiment_run_context_mock", "add_context_children_mock", + "get_or_create_default_tb_none_mock", ) def test_end_run( self, @@ -1748,6 +1854,7 @@ def test_end_run( "get_experiment_mock", "create_experiment_run_context_mock", "get_pipeline_job_mock", + "get_or_create_default_tb_none_mock", ) def test_log_pipeline_job( self, @@ -1918,3 +2025,34 @@ def test_experiment_run_get_logged_custom_jobs(self, get_custom_job_mock): name=_TEST_CUSTOM_JOB_NAME, retry=base._DEFAULT_RETRY, ) + + +class TestTensorboard: + def 
test_get_or_create_default_tb_with_existing_default( + self, list_default_tensorboard_mock + ): + tensorboard = metadata._get_or_create_default_tensorboard() + + list_default_tensorboard_mock.assert_called_once_with( + request={ + "parent": f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}", + "filter": "is_default=true", + } + ) + assert tensorboard.name == _TEST_DEFAULT_TENSORBOARD_NAME + + def test_get_or_create_default_tb_no_existing_default( + self, + list_default_tensorboard_empty_mock, + create_default_tensorboard_mock, + ): + tensorboard = metadata._get_or_create_default_tensorboard() + + list_default_tensorboard_empty_mock.assert_called_once_with( + request={ + "parent": f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}", + "filter": "is_default=true", + } + ) + create_default_tensorboard_mock.assert_called_once() + assert tensorboard.name == _TEST_DEFAULT_TENSORBOARD_NAME diff --git a/tests/unit/aiplatform/test_publisher_model.py b/tests/unit/aiplatform/test_publisher_model.py new file mode 100644 index 0000000000..5ad0eee4e9 --- /dev/null +++ b/tests/unit/aiplatform/test_publisher_model.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import pytest + +from unittest import mock +from importlib import reload + +from google.cloud import aiplatform +from google.cloud.aiplatform import base +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform.preview import _publisher_model + +from google.cloud.aiplatform.compat.services import ( + model_garden_service_client_v1beta1, +) + + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" + +_TEST_RESOURCE_NAME = "publishers/google/models/chat-bison@001" +_TEST_MODEL_GARDEN_ID = "google/chat-bison@001" +_TEST_INVALID_RESOURCE_NAME = "google.chat-bison@001" + + +@pytest.fixture +def mock_get_publisher_model(): + with mock.patch.object( + model_garden_service_client_v1beta1.ModelGardenServiceClient, + "get_publisher_model", + ) as mock_get_publisher_model: + yield mock_get_publisher_model + + +@pytest.mark.usefixtures("google_auth_mock") +class TestPublisherModel: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_publisher_model_with_resource_name(self, mock_get_publisher_model): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + _ = _publisher_model._PublisherModel(_TEST_RESOURCE_NAME) + mock_get_publisher_model.assert_called_once_with( + name=_TEST_RESOURCE_NAME, retry=base._DEFAULT_RETRY + ) + + def test_init_publisher_model_with_model_garden_id(self, mock_get_publisher_model): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + _ = _publisher_model._PublisherModel(_TEST_MODEL_GARDEN_ID) + mock_get_publisher_model.assert_called_once_with( + name=_TEST_RESOURCE_NAME, retry=base._DEFAULT_RETRY + ) + + def test_init_publisher_model_with_invalid_resource_name( + self, mock_get_publisher_model + ): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + with pytest.raises( + ValueError, + match=( + f"`{_TEST_INVALID_RESOURCE_NAME}` is not a 
valid PublisherModel " + "resource name or model garden id." + ), + ): + _ = _publisher_model._PublisherModel(_TEST_INVALID_RESOURCE_NAME) diff --git a/tests/unit/aiplatform/test_tensorboard.py b/tests/unit/aiplatform/test_tensorboard.py index 7a999905ea..ae1863c37c 100644 --- a/tests/unit/aiplatform/test_tensorboard.py +++ b/tests/unit/aiplatform/test_tensorboard.py @@ -534,6 +534,52 @@ def test_create_tensorboard_with_timeout_not_explicitly_set( timeout=None, ) + @pytest.mark.usefixtures("get_tensorboard_mock") + def test_create_tensorboard_is_default_true(self, create_tensorboard_mock): + + aiplatform.init( + project=_TEST_PROJECT, + ) + + tensorboard.Tensorboard.create( + display_name=_TEST_DISPLAY_NAME, + is_default=True, + ) + + expected_tensorboard = gca_tensorboard.Tensorboard( + display_name=_TEST_DISPLAY_NAME, is_default=True + ) + + create_tensorboard_mock.assert_called_once_with( + parent=_TEST_PARENT, + tensorboard=expected_tensorboard, + metadata=_TEST_REQUEST_METADATA, + timeout=None, + ) + + @pytest.mark.usefixtures("get_tensorboard_mock") + def test_create_tensorboard_is_default_false(self, create_tensorboard_mock): + + aiplatform.init( + project=_TEST_PROJECT, + ) + + tensorboard.Tensorboard.create( + display_name=_TEST_DISPLAY_NAME, + is_default=False, + ) + + expected_tensorboard = gca_tensorboard.Tensorboard( + display_name=_TEST_DISPLAY_NAME, is_default=False + ) + + create_tensorboard_mock.assert_called_once_with( + parent=_TEST_PARENT, + tensorboard=expected_tensorboard, + metadata=_TEST_REQUEST_METADATA, + timeout=None, + ) + @pytest.mark.usefixtures("get_tensorboard_mock") def test_delete_tensorboard(self, delete_tensorboard_mock): aiplatform.init(project=_TEST_PROJECT) @@ -580,6 +626,40 @@ def test_update_tensorboard_encryption_spec(self, update_tensorboard_mock): metadata=_TEST_REQUEST_METADATA, ) + @pytest.mark.usefixtures("get_tensorboard_mock") + def test_update_tensorboard_is_default_true(self, update_tensorboard_mock): + 
aiplatform.init(project=_TEST_PROJECT) + + my_tensorboard = tensorboard.Tensorboard(tensorboard_name=_TEST_NAME) + my_tensorboard.update(is_default=True) + + expected_tensorboard = gca_tensorboard.Tensorboard( + name=_TEST_NAME, + is_default=True, + ) + update_tensorboard_mock.assert_called_once_with( + update_mask=field_mask_pb2.FieldMask(paths=["is_default"]), + tensorboard=expected_tensorboard, + metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.usefixtures("get_tensorboard_mock") + def test_update_tensorboard_is_default_false(self, update_tensorboard_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_tensorboard = tensorboard.Tensorboard(tensorboard_name=_TEST_NAME) + my_tensorboard.update(is_default=False) + + expected_tensorboard = gca_tensorboard.Tensorboard( + name=_TEST_NAME, + is_default=False, + ) + update_tensorboard_mock.assert_called_once_with( + update_mask=field_mask_pb2.FieldMask(paths=["is_default"]), + tensorboard=expected_tensorboard, + metadata=_TEST_REQUEST_METADATA, + ) + @pytest.mark.usefixtures("google_auth_mock") class TestTensorboardExperiment: diff --git a/tests/unit/aiplatform/test_utils.py b/tests/unit/aiplatform/test_utils.py index 47e6caa421..261b3fb4c7 100644 --- a/tests/unit/aiplatform/test_utils.py +++ b/tests/unit/aiplatform/test_utils.py @@ -35,7 +35,9 @@ from google.cloud import storage from google.cloud.aiplatform import compat, utils from google.cloud.aiplatform.compat.types import pipeline_failure_policy +from google.cloud.aiplatform import datasets from google.cloud.aiplatform.utils import ( + column_transformations_utils, gcs_utils, pipeline_utils, prediction_utils, @@ -485,6 +487,51 @@ def test_timestamped_unique_name(): assert re.match(r"\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}-.{5}", name) +class TestColumnTransformationsUtils: + + column_transformations = [ + {"auto": {"column_name": "a"}}, + {"auto": {"column_name": "b"}}, + ] + column_specs = {"a": "auto", "b": "auto"} + + def 
test_get_default_column_transformations(self): + ds = mock.MagicMock(datasets.TimeSeriesDataset) + ds.column_names = ["a", "b", "target"] + ( + transforms, + columns, + ) = column_transformations_utils.get_default_column_transformations( + dataset=ds, target_column="target" + ) + assert transforms == [ + {"auto": {"column_name": "a"}}, + {"auto": {"column_name": "b"}}, + ] + assert columns == ["a", "b"] + + def test_validate_transformations_with_multiple_configs(self): + with pytest.raises(ValueError): + ( + column_transformations_utils.validate_and_get_column_transformations( + column_transformations=self.column_transformations, + column_specs=self.column_specs, + ) + ) + + def test_validate_transformations_with_column_specs(self): + actual = column_transformations_utils.validate_and_get_column_transformations( + column_specs=self.column_specs + ) + assert actual == self.column_transformations + + def test_validate_transformations_with_column_transformations(self): + actual = column_transformations_utils.validate_and_get_column_transformations( + column_transformations=self.column_transformations + ) + assert actual == self.column_transformations + + @pytest.mark.usefixtures("google_auth_mock") class TestGcsUtils: def test_upload_to_gcs(self, json_file, mock_storage_blob_upload_from_filename): diff --git a/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py index 1a44844cac..a6d86a3be8 100644 --- a/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py @@ -2668,6 +2668,303 @@ async def test_undeploy_model_flattened_error_async(): ) +@pytest.mark.parametrize( + "request_type", + [ + endpoint_service.MutateDeployedModelRequest, + dict, + ], +) +def test_mutate_deployed_model(request_type, transport: str = "grpc"): + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything 
is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.mutate_deployed_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.mutate_deployed_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == endpoint_service.MutateDeployedModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_mutate_deployed_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.mutate_deployed_model), "__call__" + ) as call: + client.mutate_deployed_model() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == endpoint_service.MutateDeployedModelRequest() + + +@pytest.mark.asyncio +async def test_mutate_deployed_model_async( + transport: str = "grpc_asyncio", + request_type=endpoint_service.MutateDeployedModelRequest, +): + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.mutate_deployed_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.mutate_deployed_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == endpoint_service.MutateDeployedModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_mutate_deployed_model_async_from_dict(): + await test_mutate_deployed_model_async(request_type=dict) + + +def test_mutate_deployed_model_field_headers(): + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.MutateDeployedModelRequest() + + request.endpoint = "endpoint_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.mutate_deployed_model), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.mutate_deployed_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "endpoint=endpoint_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_mutate_deployed_model_field_headers_async(): + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.MutateDeployedModelRequest() + + request.endpoint = "endpoint_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.mutate_deployed_model), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + await client.mutate_deployed_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "endpoint=endpoint_value", + ) in kw["metadata"] + + +def test_mutate_deployed_model_flattened(): + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.mutate_deployed_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. 
+ client.mutate_deployed_model( + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].endpoint + mock_val = "endpoint_value" + assert arg == mock_val + arg = args[0].deployed_model + mock_val = gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ) + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) + assert arg == mock_val + + +def test_mutate_deployed_model_flattened_error(): + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.mutate_deployed_model( + endpoint_service.MutateDeployedModelRequest(), + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.asyncio +async def test_mutate_deployed_model_flattened_async(): + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.mutate_deployed_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
+ call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.mutate_deployed_model( + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].endpoint + mock_val = "endpoint_value" + assert arg == mock_val + arg = args[0].deployed_model + mock_val = gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ) + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_mutate_deployed_model_flattened_error_async(): + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + await client.mutate_deployed_model( + endpoint_service.MutateDeployedModelRequest(), + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.EndpointServiceGrpcTransport( @@ -2812,6 +3109,7 @@ def test_endpoint_service_base_transport(): "delete_endpoint", "deploy_model", "undeploy_model", + "mutate_deployed_model", "set_iam_policy", "get_iam_policy", "test_iam_permissions", diff --git a/tests/unit/gapic/aiplatform_v1/test_featurestore_service.py b/tests/unit/gapic/aiplatform_v1/test_featurestore_service.py index 4b2317727f..0b20d153ac 100644 --- a/tests/unit/gapic/aiplatform_v1/test_featurestore_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_featurestore_service.py @@ -1017,6 +1017,7 @@ def test_get_featurestore(request_type, transport: str = "grpc"): name="name_value", etag="etag_value", state=featurestore.Featurestore.State.STABLE, + online_storage_ttl_days=2460, ) response = client.get_featurestore(request) @@ -1030,6 +1031,7 @@ def test_get_featurestore(request_type, transport: str = "grpc"): assert response.name == "name_value" assert response.etag == "etag_value" assert response.state == featurestore.Featurestore.State.STABLE + assert response.online_storage_ttl_days == 2460 def test_get_featurestore_empty_call(): @@ -1070,6 +1072,7 @@ async def test_get_featurestore_async( name="name_value", etag="etag_value", state=featurestore.Featurestore.State.STABLE, + online_storage_ttl_days=2460, ) ) response = await client.get_featurestore(request) @@ -1084,6 +1087,7 @@ async def test_get_featurestore_async( assert response.name == "name_value" assert 
response.etag == "etag_value" assert response.state == featurestore.Featurestore.State.STABLE + assert response.online_storage_ttl_days == 2460 @pytest.mark.asyncio @@ -2464,6 +2468,7 @@ def test_get_entity_type(request_type, transport: str = "grpc"): name="name_value", description="description_value", etag="etag_value", + offline_storage_ttl_days=2554, ) response = client.get_entity_type(request) @@ -2477,6 +2482,7 @@ def test_get_entity_type(request_type, transport: str = "grpc"): assert response.name == "name_value" assert response.description == "description_value" assert response.etag == "etag_value" + assert response.offline_storage_ttl_days == 2554 def test_get_entity_type_empty_call(): @@ -2517,6 +2523,7 @@ async def test_get_entity_type_async( name="name_value", description="description_value", etag="etag_value", + offline_storage_ttl_days=2554, ) ) response = await client.get_entity_type(request) @@ -2531,6 +2538,7 @@ async def test_get_entity_type_async( assert response.name == "name_value" assert response.description == "description_value" assert response.etag == "etag_value" + assert response.offline_storage_ttl_days == 2554 @pytest.mark.asyncio @@ -3150,6 +3158,7 @@ def test_update_entity_type(request_type, transport: str = "grpc"): name="name_value", description="description_value", etag="etag_value", + offline_storage_ttl_days=2554, ) response = client.update_entity_type(request) @@ -3163,6 +3172,7 @@ def test_update_entity_type(request_type, transport: str = "grpc"): assert response.name == "name_value" assert response.description == "description_value" assert response.etag == "etag_value" + assert response.offline_storage_ttl_days == 2554 def test_update_entity_type_empty_call(): @@ -3207,6 +3217,7 @@ async def test_update_entity_type_async( name="name_value", description="description_value", etag="etag_value", + offline_storage_ttl_days=2554, ) ) response = await client.update_entity_type(request) @@ -3221,6 +3232,7 @@ async def 
test_update_entity_type_async( assert response.name == "name_value" assert response.description == "description_value" assert response.etag == "etag_value" + assert response.offline_storage_ttl_days == 2554 @pytest.mark.asyncio diff --git a/tests/unit/gapic/aiplatform_v1/test_job_service.py b/tests/unit/gapic/aiplatform_v1/test_job_service.py index e409919726..7ef170139e 100644 --- a/tests/unit/gapic/aiplatform_v1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_job_service.py @@ -11743,10 +11743,39 @@ def test_parse_batch_prediction_job_path(): assert expected == actual -def test_custom_job_path(): +def test_context_path(): project = "cuttlefish" location = "mussel" - custom_job = "winkle" + metadata_store = "winkle" + context = "nautilus" + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/contexts/{context}".format( + project=project, + location=location, + metadata_store=metadata_store, + context=context, + ) + actual = JobServiceClient.context_path(project, location, metadata_store, context) + assert expected == actual + + +def test_parse_context_path(): + expected = { + "project": "scallop", + "location": "abalone", + "metadata_store": "squid", + "context": "clam", + } + path = JobServiceClient.context_path(**expected) + + # Check that the path construction is reversible. 
+ actual = JobServiceClient.parse_context_path(path) + assert expected == actual + + +def test_custom_job_path(): + project = "whelk" + location = "octopus" + custom_job = "oyster" expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format( project=project, location=location, @@ -11758,9 +11787,9 @@ def test_custom_job_path(): def test_parse_custom_job_path(): expected = { - "project": "nautilus", - "location": "scallop", - "custom_job": "abalone", + "project": "nudibranch", + "location": "cuttlefish", + "custom_job": "mussel", } path = JobServiceClient.custom_job_path(**expected) @@ -11770,9 +11799,9 @@ def test_parse_custom_job_path(): def test_data_labeling_job_path(): - project = "squid" - location = "clam" - data_labeling_job = "whelk" + project = "winkle" + location = "nautilus" + data_labeling_job = "scallop" expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( project=project, location=location, @@ -11786,9 +11815,9 @@ def test_data_labeling_job_path(): def test_parse_data_labeling_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "data_labeling_job": "nudibranch", + "project": "abalone", + "location": "squid", + "data_labeling_job": "clam", } path = JobServiceClient.data_labeling_job_path(**expected) @@ -11798,9 +11827,9 @@ def test_parse_data_labeling_job_path(): def test_dataset_path(): - project = "cuttlefish" - location = "mussel" - dataset = "winkle" + project = "whelk" + location = "octopus" + dataset = "oyster" expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( project=project, location=location, @@ -11812,9 +11841,9 @@ def test_dataset_path(): def test_parse_dataset_path(): expected = { - "project": "nautilus", - "location": "scallop", - "dataset": "abalone", + "project": "nudibranch", + "location": "cuttlefish", + "dataset": "mussel", } path = JobServiceClient.dataset_path(**expected) @@ -11824,9 +11853,9 @@ def 
test_parse_dataset_path(): def test_endpoint_path(): - project = "squid" - location = "clam" - endpoint = "whelk" + project = "winkle" + location = "nautilus" + endpoint = "scallop" expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( project=project, location=location, @@ -11838,9 +11867,9 @@ def test_endpoint_path(): def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", + "project": "abalone", + "location": "squid", + "endpoint": "clam", } path = JobServiceClient.endpoint_path(**expected) @@ -11850,9 +11879,9 @@ def test_parse_endpoint_path(): def test_hyperparameter_tuning_job_path(): - project = "cuttlefish" - location = "mussel" - hyperparameter_tuning_job = "winkle" + project = "whelk" + location = "octopus" + hyperparameter_tuning_job = "oyster" expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( project=project, location=location, @@ -11866,9 +11895,9 @@ def test_hyperparameter_tuning_job_path(): def test_parse_hyperparameter_tuning_job_path(): expected = { - "project": "nautilus", - "location": "scallop", - "hyperparameter_tuning_job": "abalone", + "project": "nudibranch", + "location": "cuttlefish", + "hyperparameter_tuning_job": "mussel", } path = JobServiceClient.hyperparameter_tuning_job_path(**expected) @@ -11878,9 +11907,9 @@ def test_parse_hyperparameter_tuning_job_path(): def test_model_path(): - project = "squid" - location = "clam" - model = "whelk" + project = "winkle" + location = "nautilus" + model = "scallop" expected = "projects/{project}/locations/{location}/models/{model}".format( project=project, location=location, @@ -11892,9 +11921,9 @@ def test_model_path(): def test_parse_model_path(): expected = { - "project": "octopus", - "location": "oyster", - "model": "nudibranch", + "project": "abalone", + "location": "squid", + "model": "clam", } path = 
JobServiceClient.model_path(**expected) @@ -11904,9 +11933,9 @@ def test_parse_model_path(): def test_model_deployment_monitoring_job_path(): - project = "cuttlefish" - location = "mussel" - model_deployment_monitoring_job = "winkle" + project = "whelk" + location = "octopus" + model_deployment_monitoring_job = "oyster" expected = "projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}".format( project=project, location=location, @@ -11920,9 +11949,9 @@ def test_model_deployment_monitoring_job_path(): def test_parse_model_deployment_monitoring_job_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model_deployment_monitoring_job": "abalone", + "project": "nudibranch", + "location": "cuttlefish", + "model_deployment_monitoring_job": "mussel", } path = JobServiceClient.model_deployment_monitoring_job_path(**expected) @@ -11932,9 +11961,9 @@ def test_parse_model_deployment_monitoring_job_path(): def test_nas_job_path(): - project = "squid" - location = "clam" - nas_job = "whelk" + project = "winkle" + location = "nautilus" + nas_job = "scallop" expected = "projects/{project}/locations/{location}/nasJobs/{nas_job}".format( project=project, location=location, @@ -11946,9 +11975,9 @@ def test_nas_job_path(): def test_parse_nas_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "nas_job": "nudibranch", + "project": "abalone", + "location": "squid", + "nas_job": "clam", } path = JobServiceClient.nas_job_path(**expected) @@ -11958,10 +11987,10 @@ def test_parse_nas_job_path(): def test_nas_trial_detail_path(): - project = "cuttlefish" - location = "mussel" - nas_job = "winkle" - nas_trial_detail = "nautilus" + project = "whelk" + location = "octopus" + nas_job = "oyster" + nas_trial_detail = "nudibranch" expected = "projects/{project}/locations/{location}/nasJobs/{nas_job}/nasTrialDetails/{nas_trial_detail}".format( project=project, location=location, @@ -11976,10 +12005,10 @@ 
def test_nas_trial_detail_path(): def test_parse_nas_trial_detail_path(): expected = { - "project": "scallop", - "location": "abalone", - "nas_job": "squid", - "nas_trial_detail": "clam", + "project": "cuttlefish", + "location": "mussel", + "nas_job": "winkle", + "nas_trial_detail": "nautilus", } path = JobServiceClient.nas_trial_detail_path(**expected) @@ -11989,8 +12018,8 @@ def test_parse_nas_trial_detail_path(): def test_network_path(): - project = "whelk" - network = "octopus" + project = "scallop" + network = "abalone" expected = "projects/{project}/global/networks/{network}".format( project=project, network=network, @@ -12001,8 +12030,8 @@ def test_network_path(): def test_parse_network_path(): expected = { - "project": "oyster", - "network": "nudibranch", + "project": "squid", + "network": "clam", } path = JobServiceClient.network_path(**expected) @@ -12012,9 +12041,9 @@ def test_parse_network_path(): def test_tensorboard_path(): - project = "cuttlefish" - location = "mussel" - tensorboard = "winkle" + project = "whelk" + location = "octopus" + tensorboard = "oyster" expected = ( "projects/{project}/locations/{location}/tensorboards/{tensorboard}".format( project=project, @@ -12028,9 +12057,9 @@ def test_tensorboard_path(): def test_parse_tensorboard_path(): expected = { - "project": "nautilus", - "location": "scallop", - "tensorboard": "abalone", + "project": "nudibranch", + "location": "cuttlefish", + "tensorboard": "mussel", } path = JobServiceClient.tensorboard_path(**expected) @@ -12040,10 +12069,10 @@ def test_parse_tensorboard_path(): def test_trial_path(): - project = "squid" - location = "clam" - study = "whelk" - trial = "octopus" + project = "winkle" + location = "nautilus" + study = "scallop" + trial = "abalone" expected = ( "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( project=project, @@ -12058,10 +12087,10 @@ def test_trial_path(): def test_parse_trial_path(): expected = { - "project": "oyster", - 
"location": "nudibranch", - "study": "cuttlefish", - "trial": "mussel", + "project": "squid", + "location": "clam", + "study": "whelk", + "trial": "octopus", } path = JobServiceClient.trial_path(**expected) @@ -12071,7 +12100,7 @@ def test_parse_trial_path(): def test_common_billing_account_path(): - billing_account = "winkle" + billing_account = "oyster" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -12081,7 +12110,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "nautilus", + "billing_account": "nudibranch", } path = JobServiceClient.common_billing_account_path(**expected) @@ -12091,7 +12120,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "scallop" + folder = "cuttlefish" expected = "folders/{folder}".format( folder=folder, ) @@ -12101,7 +12130,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "abalone", + "folder": "mussel", } path = JobServiceClient.common_folder_path(**expected) @@ -12111,7 +12140,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "squid" + organization = "winkle" expected = "organizations/{organization}".format( organization=organization, ) @@ -12121,7 +12150,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "clam", + "organization": "nautilus", } path = JobServiceClient.common_organization_path(**expected) @@ -12131,7 +12160,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "whelk" + project = "scallop" expected = "projects/{project}".format( project=project, ) @@ -12141,7 +12170,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "octopus", + "project": "abalone", } path = JobServiceClient.common_project_path(**expected) @@ 
-12151,8 +12180,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "oyster" - location = "nudibranch" + project = "squid" + location = "clam" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -12163,8 +12192,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "cuttlefish", - "location": "mussel", + "project": "whelk", + "location": "octopus", } path = JobServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1/test_migration_service.py index 9055645a91..8fc5db291f 100644 --- a/tests/unit/gapic/aiplatform_v1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_migration_service.py @@ -2030,19 +2030,22 @@ def test_parse_dataset_path(): def test_dataset_path(): project = "squid" - dataset = "clam" - expected = "projects/{project}/datasets/{dataset}".format( + location = "clam" + dataset = "whelk" + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( project=project, + location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, dataset) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "whelk", - "dataset": "octopus", + "project": "octopus", + "location": "oyster", + "dataset": "nudibranch", } path = MigrationServiceClient.dataset_path(**expected) @@ -2052,22 +2055,19 @@ def test_parse_dataset_path(): def test_dataset_path(): - project = "oyster" - location = "nudibranch" - dataset = "cuttlefish" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project = "cuttlefish" + dataset = "mussel" + expected = "projects/{project}/datasets/{dataset}".format( project=project, - location=location, dataset=dataset, ) - actual = 
MigrationServiceClient.dataset_path(project, location, dataset) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "mussel", - "location": "winkle", + "project": "winkle", "dataset": "nautilus", } path = MigrationServiceClient.dataset_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py index c61ca6d9b4..5be074e237 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py @@ -2670,6 +2670,303 @@ async def test_undeploy_model_flattened_error_async(): ) +@pytest.mark.parametrize( + "request_type", + [ + endpoint_service.MutateDeployedModelRequest, + dict, + ], +) +def test_mutate_deployed_model(request_type, transport: str = "grpc"): + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.mutate_deployed_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.mutate_deployed_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == endpoint_service.MutateDeployedModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_mutate_deployed_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. 
request == None and no flattened fields passed, work. + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.mutate_deployed_model), "__call__" + ) as call: + client.mutate_deployed_model() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == endpoint_service.MutateDeployedModelRequest() + + +@pytest.mark.asyncio +async def test_mutate_deployed_model_async( + transport: str = "grpc_asyncio", + request_type=endpoint_service.MutateDeployedModelRequest, +): + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.mutate_deployed_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.mutate_deployed_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == endpoint_service.MutateDeployedModelRequest() + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_mutate_deployed_model_async_from_dict(): + await test_mutate_deployed_model_async(request_type=dict) + + +def test_mutate_deployed_model_field_headers(): + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.MutateDeployedModelRequest() + + request.endpoint = "endpoint_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.mutate_deployed_model), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.mutate_deployed_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "endpoint=endpoint_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_mutate_deployed_model_field_headers_async(): + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.MutateDeployedModelRequest() + + request.endpoint = "endpoint_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.mutate_deployed_model), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + await client.mutate_deployed_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "endpoint=endpoint_value", + ) in kw["metadata"] + + +def test_mutate_deployed_model_flattened(): + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.mutate_deployed_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.mutate_deployed_model( + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].endpoint + mock_val = "endpoint_value" + assert arg == mock_val + arg = args[0].deployed_model + mock_val = gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ) + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) + assert arg == mock_val + + +def test_mutate_deployed_model_flattened_error(): + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.mutate_deployed_model( + endpoint_service.MutateDeployedModelRequest(), + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.asyncio +async def test_mutate_deployed_model_flattened_async(): + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.mutate_deployed_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.mutate_deployed_model( + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].endpoint + mock_val = "endpoint_value" + assert arg == mock_val + arg = args[0].deployed_model + mock_val = gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ) + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_mutate_deployed_model_flattened_error_async(): + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.mutate_deployed_model( + endpoint_service.MutateDeployedModelRequest(), + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. 
transport = transports.EndpointServiceGrpcTransport( @@ -2814,6 +3111,7 @@ def test_endpoint_service_base_transport(): "delete_endpoint", "deploy_model", "undeploy_model", + "mutate_deployed_model", "set_iam_policy", "get_iam_policy", "test_iam_permissions", diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py index a4ea1e4fa7..6736a16778 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py @@ -11746,10 +11746,39 @@ def test_parse_batch_prediction_job_path(): assert expected == actual -def test_custom_job_path(): +def test_context_path(): project = "cuttlefish" location = "mussel" - custom_job = "winkle" + metadata_store = "winkle" + context = "nautilus" + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/contexts/{context}".format( + project=project, + location=location, + metadata_store=metadata_store, + context=context, + ) + actual = JobServiceClient.context_path(project, location, metadata_store, context) + assert expected == actual + + +def test_parse_context_path(): + expected = { + "project": "scallop", + "location": "abalone", + "metadata_store": "squid", + "context": "clam", + } + path = JobServiceClient.context_path(**expected) + + # Check that the path construction is reversible. 
+ actual = JobServiceClient.parse_context_path(path) + assert expected == actual + + +def test_custom_job_path(): + project = "whelk" + location = "octopus" + custom_job = "oyster" expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format( project=project, location=location, @@ -11761,9 +11790,9 @@ def test_custom_job_path(): def test_parse_custom_job_path(): expected = { - "project": "nautilus", - "location": "scallop", - "custom_job": "abalone", + "project": "nudibranch", + "location": "cuttlefish", + "custom_job": "mussel", } path = JobServiceClient.custom_job_path(**expected) @@ -11773,9 +11802,9 @@ def test_parse_custom_job_path(): def test_data_labeling_job_path(): - project = "squid" - location = "clam" - data_labeling_job = "whelk" + project = "winkle" + location = "nautilus" + data_labeling_job = "scallop" expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( project=project, location=location, @@ -11789,9 +11818,9 @@ def test_data_labeling_job_path(): def test_parse_data_labeling_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "data_labeling_job": "nudibranch", + "project": "abalone", + "location": "squid", + "data_labeling_job": "clam", } path = JobServiceClient.data_labeling_job_path(**expected) @@ -11801,9 +11830,9 @@ def test_parse_data_labeling_job_path(): def test_dataset_path(): - project = "cuttlefish" - location = "mussel" - dataset = "winkle" + project = "whelk" + location = "octopus" + dataset = "oyster" expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( project=project, location=location, @@ -11815,9 +11844,9 @@ def test_dataset_path(): def test_parse_dataset_path(): expected = { - "project": "nautilus", - "location": "scallop", - "dataset": "abalone", + "project": "nudibranch", + "location": "cuttlefish", + "dataset": "mussel", } path = JobServiceClient.dataset_path(**expected) @@ -11827,9 +11856,9 @@ def 
test_parse_dataset_path(): def test_endpoint_path(): - project = "squid" - location = "clam" - endpoint = "whelk" + project = "winkle" + location = "nautilus" + endpoint = "scallop" expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( project=project, location=location, @@ -11841,9 +11870,9 @@ def test_endpoint_path(): def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", + "project": "abalone", + "location": "squid", + "endpoint": "clam", } path = JobServiceClient.endpoint_path(**expected) @@ -11853,9 +11882,9 @@ def test_parse_endpoint_path(): def test_hyperparameter_tuning_job_path(): - project = "cuttlefish" - location = "mussel" - hyperparameter_tuning_job = "winkle" + project = "whelk" + location = "octopus" + hyperparameter_tuning_job = "oyster" expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( project=project, location=location, @@ -11869,9 +11898,9 @@ def test_hyperparameter_tuning_job_path(): def test_parse_hyperparameter_tuning_job_path(): expected = { - "project": "nautilus", - "location": "scallop", - "hyperparameter_tuning_job": "abalone", + "project": "nudibranch", + "location": "cuttlefish", + "hyperparameter_tuning_job": "mussel", } path = JobServiceClient.hyperparameter_tuning_job_path(**expected) @@ -11881,9 +11910,9 @@ def test_parse_hyperparameter_tuning_job_path(): def test_model_path(): - project = "squid" - location = "clam" - model = "whelk" + project = "winkle" + location = "nautilus" + model = "scallop" expected = "projects/{project}/locations/{location}/models/{model}".format( project=project, location=location, @@ -11895,9 +11924,9 @@ def test_model_path(): def test_parse_model_path(): expected = { - "project": "octopus", - "location": "oyster", - "model": "nudibranch", + "project": "abalone", + "location": "squid", + "model": "clam", } path = 
JobServiceClient.model_path(**expected) @@ -11907,9 +11936,9 @@ def test_parse_model_path(): def test_model_deployment_monitoring_job_path(): - project = "cuttlefish" - location = "mussel" - model_deployment_monitoring_job = "winkle" + project = "whelk" + location = "octopus" + model_deployment_monitoring_job = "oyster" expected = "projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}".format( project=project, location=location, @@ -11923,9 +11952,9 @@ def test_model_deployment_monitoring_job_path(): def test_parse_model_deployment_monitoring_job_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model_deployment_monitoring_job": "abalone", + "project": "nudibranch", + "location": "cuttlefish", + "model_deployment_monitoring_job": "mussel", } path = JobServiceClient.model_deployment_monitoring_job_path(**expected) @@ -11935,9 +11964,9 @@ def test_parse_model_deployment_monitoring_job_path(): def test_nas_job_path(): - project = "squid" - location = "clam" - nas_job = "whelk" + project = "winkle" + location = "nautilus" + nas_job = "scallop" expected = "projects/{project}/locations/{location}/nasJobs/{nas_job}".format( project=project, location=location, @@ -11949,9 +11978,9 @@ def test_nas_job_path(): def test_parse_nas_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "nas_job": "nudibranch", + "project": "abalone", + "location": "squid", + "nas_job": "clam", } path = JobServiceClient.nas_job_path(**expected) @@ -11961,10 +11990,10 @@ def test_parse_nas_job_path(): def test_nas_trial_detail_path(): - project = "cuttlefish" - location = "mussel" - nas_job = "winkle" - nas_trial_detail = "nautilus" + project = "whelk" + location = "octopus" + nas_job = "oyster" + nas_trial_detail = "nudibranch" expected = "projects/{project}/locations/{location}/nasJobs/{nas_job}/nasTrialDetails/{nas_trial_detail}".format( project=project, location=location, @@ -11979,10 +12008,10 @@ 
def test_nas_trial_detail_path(): def test_parse_nas_trial_detail_path(): expected = { - "project": "scallop", - "location": "abalone", - "nas_job": "squid", - "nas_trial_detail": "clam", + "project": "cuttlefish", + "location": "mussel", + "nas_job": "winkle", + "nas_trial_detail": "nautilus", } path = JobServiceClient.nas_trial_detail_path(**expected) @@ -11992,8 +12021,8 @@ def test_parse_nas_trial_detail_path(): def test_network_path(): - project = "whelk" - network = "octopus" + project = "scallop" + network = "abalone" expected = "projects/{project}/global/networks/{network}".format( project=project, network=network, @@ -12004,8 +12033,8 @@ def test_network_path(): def test_parse_network_path(): expected = { - "project": "oyster", - "network": "nudibranch", + "project": "squid", + "network": "clam", } path = JobServiceClient.network_path(**expected) @@ -12015,8 +12044,8 @@ def test_parse_network_path(): def test_notification_channel_path(): - project = "cuttlefish" - notification_channel = "mussel" + project = "whelk" + notification_channel = "octopus" expected = "projects/{project}/notificationChannels/{notification_channel}".format( project=project, notification_channel=notification_channel, @@ -12027,8 +12056,8 @@ def test_notification_channel_path(): def test_parse_notification_channel_path(): expected = { - "project": "winkle", - "notification_channel": "nautilus", + "project": "oyster", + "notification_channel": "nudibranch", } path = JobServiceClient.notification_channel_path(**expected) @@ -12038,9 +12067,9 @@ def test_parse_notification_channel_path(): def test_tensorboard_path(): - project = "scallop" - location = "abalone" - tensorboard = "squid" + project = "cuttlefish" + location = "mussel" + tensorboard = "winkle" expected = ( "projects/{project}/locations/{location}/tensorboards/{tensorboard}".format( project=project, @@ -12054,9 +12083,9 @@ def test_tensorboard_path(): def test_parse_tensorboard_path(): expected = { - "project": "clam", - 
"location": "whelk", - "tensorboard": "octopus", + "project": "nautilus", + "location": "scallop", + "tensorboard": "abalone", } path = JobServiceClient.tensorboard_path(**expected) @@ -12066,10 +12095,10 @@ def test_parse_tensorboard_path(): def test_trial_path(): - project = "oyster" - location = "nudibranch" - study = "cuttlefish" - trial = "mussel" + project = "squid" + location = "clam" + study = "whelk" + trial = "octopus" expected = ( "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( project=project, @@ -12084,10 +12113,10 @@ def test_trial_path(): def test_parse_trial_path(): expected = { - "project": "winkle", - "location": "nautilus", - "study": "scallop", - "trial": "abalone", + "project": "oyster", + "location": "nudibranch", + "study": "cuttlefish", + "trial": "mussel", } path = JobServiceClient.trial_path(**expected) @@ -12097,7 +12126,7 @@ def test_parse_trial_path(): def test_common_billing_account_path(): - billing_account = "squid" + billing_account = "winkle" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -12107,7 +12136,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", + "billing_account": "nautilus", } path = JobServiceClient.common_billing_account_path(**expected) @@ -12117,7 +12146,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "whelk" + folder = "scallop" expected = "folders/{folder}".format( folder=folder, ) @@ -12127,7 +12156,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "octopus", + "folder": "abalone", } path = JobServiceClient.common_folder_path(**expected) @@ -12137,7 +12166,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "oyster" + organization = "squid" expected = "organizations/{organization}".format( organization=organization, ) @@ 
-12147,7 +12176,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", + "organization": "clam", } path = JobServiceClient.common_organization_path(**expected) @@ -12157,7 +12186,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "cuttlefish" + project = "whelk" expected = "projects/{project}".format( project=project, ) @@ -12167,7 +12196,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "mussel", + "project": "octopus", } path = JobServiceClient.common_project_path(**expected) @@ -12177,8 +12206,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "winkle" - location = "nautilus" + project = "oyster" + location = "nudibranch" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -12189,8 +12218,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", + "project": "cuttlefish", + "location": "mussel", } path = JobServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_model_garden_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_model_garden_service.py new file mode 100644 index 0000000000..0d8d1c6fed --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/test_model_garden_service.py @@ -0,0 +1,3208 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER +except ImportError: # pragma: NO COVER + import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule +from proto.marshal.rules import wrappers + +from google.api_core import client_options +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import path_template +from google.auth import credentials as ga_credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.model_garden_service import ( + ModelGardenServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.model_garden_service import ( + ModelGardenServiceClient, +) +from google.cloud.aiplatform_v1beta1.services.model_garden_service import transports +from google.cloud.aiplatform_v1beta1.types import model +from google.cloud.aiplatform_v1beta1.types import model_garden_service +from google.cloud.aiplatform_v1beta1.types import publisher_model +from google.cloud.location import locations_pb2 +from google.iam.v1 import iam_policy_pb2 # type: ignore +from google.iam.v1 import options_pb2 # type: ignore +from google.iam.v1 import policy_pb2 # type: ignore +from google.longrunning import operations_pb2 +from google.oauth2 import 
service_account +import google.auth + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert ModelGardenServiceClient._get_default_mtls_endpoint(None) is None + assert ( + ModelGardenServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + ModelGardenServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + ModelGardenServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + ModelGardenServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + ModelGardenServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (ModelGardenServiceClient, "grpc"), + (ModelGardenServiceAsyncClient, "grpc_asyncio"), + ], +) +def test_model_garden_service_client_from_service_account_info( + client_class, transport_name +): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info, transport=transport_name) + assert client.transport._credentials == 
creds + assert isinstance(client, client_class) + + assert client.transport._host == ("aiplatform.googleapis.com:443") + + +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.ModelGardenServiceGrpcTransport, "grpc"), + (transports.ModelGardenServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_model_garden_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (ModelGardenServiceClient, "grpc"), + (ModelGardenServiceAsyncClient, "grpc_asyncio"), + ], +) +def test_model_garden_service_client_from_service_account_file( + client_class, transport_name +): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ("aiplatform.googleapis.com:443") + + +def test_model_garden_service_client_get_transport_class(): + transport = 
ModelGardenServiceClient.get_transport_class() + available_transports = [ + transports.ModelGardenServiceGrpcTransport, + ] + assert transport in available_transports + + transport = ModelGardenServiceClient.get_transport_class("grpc") + assert transport == transports.ModelGardenServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelGardenServiceClient, transports.ModelGardenServiceGrpcTransport, "grpc"), + ( + ModelGardenServiceAsyncClient, + transports.ModelGardenServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + ModelGardenServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelGardenServiceClient), +) +@mock.patch.object( + ModelGardenServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelGardenServiceAsyncClient), +) +def test_model_garden_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(ModelGardenServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(ModelGardenServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. 
+ options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name, client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class(transport=transport_name) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class(transport=transport_name) + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + # Check the case api_endpoint is provided + options = client_options.ClientOptions( + api_audience="https://language.googleapis.com" + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience="https://language.googleapis.com", + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + ModelGardenServiceClient, + transports.ModelGardenServiceGrpcTransport, + "grpc", + "true", + ), + ( + ModelGardenServiceAsyncClient, + transports.ModelGardenServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + 
ModelGardenServiceClient, + transports.ModelGardenServiceGrpcTransport, + "grpc", + "false", + ), + ( + ModelGardenServiceAsyncClient, + transports.ModelGardenServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + ModelGardenServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelGardenServiceClient), +) +@mock.patch.object( + ModelGardenServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelGardenServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_model_garden_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case ADC client cert is provided. 
Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case client_cert_source and ADC client cert are not provided. 
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class", [ModelGardenServiceClient, ModelGardenServiceAsyncClient] +) +@mock.patch.object( + ModelGardenServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelGardenServiceClient), +) +@mock.patch.object( + ModelGardenServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelGardenServiceAsyncClient), +) +def test_model_garden_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelGardenServiceClient, transports.ModelGardenServiceGrpcTransport, "grpc"), + ( + ModelGardenServiceAsyncClient, + transports.ModelGardenServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_model_garden_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + ModelGardenServiceClient, + transports.ModelGardenServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + ModelGardenServiceAsyncClient, + transports.ModelGardenServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ], +) +def test_model_garden_service_client_client_options_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case 
credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +def test_model_garden_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.model_garden_service.transports.ModelGardenServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = ModelGardenServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + ModelGardenServiceClient, + transports.ModelGardenServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + ModelGardenServiceAsyncClient, + transports.ModelGardenServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ], +) +def test_model_garden_service_client_create_channel_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. 
+ options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "request_type", + [ + model_garden_service.GetPublisherModelRequest, + dict, + ], +) +def test_get_publisher_model(request_type, transport: str = "grpc"): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty 
request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_publisher_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = publisher_model.PublisherModel( + name="name_value", + version_id="version_id_value", + open_source_category=publisher_model.PublisherModel.OpenSourceCategory.PROPRIETARY, + frameworks=["frameworks_value"], + publisher_model_template="publisher_model_template_value", + ) + response = client.get_publisher_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == model_garden_service.GetPublisherModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, publisher_model.PublisherModel) + assert response.name == "name_value" + assert response.version_id == "version_id_value" + assert ( + response.open_source_category + == publisher_model.PublisherModel.OpenSourceCategory.PROPRIETARY + ) + assert response.frameworks == ["frameworks_value"] + assert response.publisher_model_template == "publisher_model_template_value" + + +def test_get_publisher_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.get_publisher_model), "__call__" + ) as call: + client.get_publisher_model() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == model_garden_service.GetPublisherModelRequest() + + +@pytest.mark.asyncio +async def test_get_publisher_model_async( + transport: str = "grpc_asyncio", + request_type=model_garden_service.GetPublisherModelRequest, +): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_publisher_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + publisher_model.PublisherModel( + name="name_value", + version_id="version_id_value", + open_source_category=publisher_model.PublisherModel.OpenSourceCategory.PROPRIETARY, + frameworks=["frameworks_value"], + publisher_model_template="publisher_model_template_value", + ) + ) + response = await client.get_publisher_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == model_garden_service.GetPublisherModelRequest() + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, publisher_model.PublisherModel) + assert response.name == "name_value" + assert response.version_id == "version_id_value" + assert ( + response.open_source_category + == publisher_model.PublisherModel.OpenSourceCategory.PROPRIETARY + ) + assert response.frameworks == ["frameworks_value"] + assert response.publisher_model_template == "publisher_model_template_value" + + +@pytest.mark.asyncio +async def test_get_publisher_model_async_from_dict(): + await test_get_publisher_model_async(request_type=dict) + + +def test_get_publisher_model_field_headers(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_garden_service.GetPublisherModelRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_publisher_model), "__call__" + ) as call: + call.return_value = publisher_model.PublisherModel() + client.get_publisher_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_publisher_model_field_headers_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_garden_service.GetPublisherModelRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.get_publisher_model), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + publisher_model.PublisherModel() + ) + await client.get_publisher_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_get_publisher_model_flattened(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_publisher_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = publisher_model.PublisherModel() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_publisher_model( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_get_publisher_model_flattened_error(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.get_publisher_model( + model_garden_service.GetPublisherModelRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_publisher_model_flattened_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_publisher_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = publisher_model.PublisherModel() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + publisher_model.PublisherModel() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_publisher_model( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_get_publisher_model_flattened_error_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_publisher_model( + model_garden_service.GetPublisherModelRequest(), + name="name_value", + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. 
+ transport = transports.ModelGardenServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.ModelGardenServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = ModelGardenServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.ModelGardenServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = ModelGardenServiceClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = ModelGardenServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.ModelGardenServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = ModelGardenServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. 
+ transport = transports.ModelGardenServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = ModelGardenServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.ModelGardenServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.ModelGardenServiceGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.ModelGardenServiceGrpcTransport, + transports.ModelGardenServiceGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + ], +) +def test_transport_kind(transport_name): + transport = ModelGardenServiceClient.get_transport_class(transport_name)( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert transport.kind == transport_name + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. 
+ client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.ModelGardenServiceGrpcTransport, + ) + + +def test_model_garden_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.ModelGardenServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_model_garden_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.model_garden_service.transports.ModelGardenServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.ModelGardenServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. 
+ methods = ( + "get_publisher_model", + "set_iam_policy", + "get_iam_policy", + "test_iam_permissions", + "get_location", + "list_locations", + "get_operation", + "wait_operation", + "cancel_operation", + "delete_operation", + "list_operations", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + # Catch all for all remaining methods and properties + remainder = [ + "kind", + ] + for r in remainder: + with pytest.raises(NotImplementedError): + getattr(transport, r)() + + +def test_model_garden_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.model_garden_service.transports.ModelGardenServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.ModelGardenServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_model_garden_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. 
+ with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.model_garden_service.transports.ModelGardenServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.ModelGardenServiceTransport() + adc.assert_called_once() + + +def test_model_garden_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + ModelGardenServiceClient() + adc.assert_called_once_with( + scopes=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.ModelGardenServiceGrpcTransport, + transports.ModelGardenServiceGrpcAsyncIOTransport, + ], +) +def test_model_garden_service_transport_auth_adc(transport_class): + # If credentials and host are not provided, the transport class should use + # ADC credentials. 
+ with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + adc.assert_called_once_with( + scopes=["1", "2"], + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.ModelGardenServiceGrpcTransport, + transports.ModelGardenServiceGrpcAsyncIOTransport, + ], +) +def test_model_garden_service_transport_auth_gdch_credentials(transport_class): + host = "https://language.com" + api_audience_tests = [None, "https://language2.com"] + api_audience_expect = [host, "https://language2.com"] + for t, e in zip(api_audience_tests, api_audience_expect): + with mock.patch.object(google.auth, "default", autospec=True) as adc: + gdch_mock = mock.MagicMock() + type(gdch_mock).with_gdch_audience = mock.PropertyMock( + return_value=gdch_mock + ) + adc.return_value = (gdch_mock, None) + transport_class(host=host, api_audience=t) + gdch_mock.with_gdch_audience.assert_called_once_with(e) + + +@pytest.mark.parametrize( + "transport_class,grpc_helpers", + [ + (transports.ModelGardenServiceGrpcTransport, grpc_helpers), + (transports.ModelGardenServiceGrpcAsyncIOTransport, grpc_helpers_async), + ], +) +def test_model_garden_service_transport_create_channel(transport_class, grpc_helpers): + # If credentials and host are not provided, the transport class should use + # ADC credentials. 
+ with mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel", autospec=True + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + adc.return_value = (creds, None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=creds, + credentials_file=None, + quota_project_id="octopus", + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=["1", "2"], + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.ModelGardenServiceGrpcTransport, + transports.ModelGardenServiceGrpcAsyncIOTransport, + ], +) +def test_model_garden_service_grpc_transport_client_cert_source_for_mtls( + transport_class, +): + cred = ga_credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. 
+ with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + ], +) +def test_model_garden_service_host_no_port(transport_name): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), + transport=transport_name, + ) + assert client.transport._host == ("aiplatform.googleapis.com:443") + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + ], +) +def test_model_garden_service_host_with_port(transport_name): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), + transport=transport_name, + ) + assert client.transport._host == ("aiplatform.googleapis.com:8000") + + +def test_model_garden_service_grpc_transport_channel(): + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.ModelGardenServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_model_garden_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. 
+ transport = transports.ModelGardenServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [ + transports.ModelGardenServiceGrpcTransport, + transports.ModelGardenServiceGrpcAsyncIOTransport, + ], +) +def test_model_garden_service_transport_channel_mtls_with_client_cert_source( + transport_class, +): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = ga_credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, 
client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [ + transports.ModelGardenServiceGrpcTransport, + transports.ModelGardenServiceGrpcAsyncIOTransport, + ], +) +def test_model_garden_service_transport_channel_mtls_with_adc(transport_class): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_publisher_model_path(): + publisher = "squid" + model = "clam" + expected = "publishers/{publisher}/models/{model}".format( + publisher=publisher, + model=model, + ) + actual = ModelGardenServiceClient.publisher_model_path(publisher, model) + assert expected == actual + + +def test_parse_publisher_model_path(): + expected = { + "publisher": "whelk", + "model": "octopus", + } + path = ModelGardenServiceClient.publisher_model_path(**expected) + + # Check that the path construction is reversible. 
+ actual = ModelGardenServiceClient.parse_publisher_model_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "oyster" + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = ModelGardenServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "nudibranch", + } + path = ModelGardenServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = ModelGardenServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "cuttlefish" + expected = "folders/{folder}".format( + folder=folder, + ) + actual = ModelGardenServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "mussel", + } + path = ModelGardenServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = ModelGardenServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "winkle" + expected = "organizations/{organization}".format( + organization=organization, + ) + actual = ModelGardenServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nautilus", + } + path = ModelGardenServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. 
+ actual = ModelGardenServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "scallop" + expected = "projects/{project}".format( + project=project, + ) + actual = ModelGardenServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "abalone", + } + path = ModelGardenServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = ModelGardenServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "squid" + location = "clam" + expected = "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + actual = ModelGardenServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "whelk", + "location": "octopus", + } + path = ModelGardenServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. 
+ actual = ModelGardenServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_with_default_client_info(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.ModelGardenServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.ModelGardenServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = ModelGardenServiceClient.get_transport_class() + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_delete_operation(transport: str = "grpc"): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.DeleteOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + response = client.delete_operation(request) + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_delete_operation_async(transport: str = "grpc"): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.DeleteOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.delete_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_delete_operation_field_headers(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.DeleteOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_operation), "__call__") as call: + call.return_value = None + + client.delete_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_operation_field_headers_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.DeleteOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.delete_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_delete_operation_from_dict(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.delete_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_delete_operation_from_dict_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_operation), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.delete_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_cancel_operation(transport: str = "grpc"): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.CancelOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.cancel_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + response = client.cancel_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_cancel_operation_async(transport: str = "grpc"): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.CancelOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.cancel_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.cancel_operation(request) + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_operation_field_headers(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.CancelOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.cancel_operation), "__call__") as call: + call.return_value = None + + client.cancel_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_cancel_operation_field_headers_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.CancelOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.cancel_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.cancel_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_cancel_operation_from_dict(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.cancel_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.cancel_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_cancel_operation_from_dict_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.cancel_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.cancel_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_wait_operation(transport: str = "grpc"): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.WaitOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.wait_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + response = client.wait_operation(request) + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +@pytest.mark.asyncio +async def test_wait_operation(transport: str = "grpc"): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.WaitOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.wait_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.wait_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_wait_operation_field_headers(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.WaitOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.wait_operation), "__call__") as call: + call.return_value = operations_pb2.Operation() + + client.wait_operation(request) + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_wait_operation_field_headers_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.WaitOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.wait_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + await client.wait_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_wait_operation_from_dict(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.wait_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + + response = client.wait_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_wait_operation_from_dict_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.wait_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.wait_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_get_operation(transport: str = "grpc"): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + response = client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +@pytest.mark.asyncio +async def test_get_operation_async(transport: str = "grpc"): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_get_operation_field_headers(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = operations_pb2.Operation() + + client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_operation_field_headers_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_get_operation_from_dict(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + + response = client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_get_operation_from_dict_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_list_operations(transport: str = "grpc"): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. 
+ request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + response = client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +@pytest.mark.asyncio +async def test_list_operations_async(transport: str = "grpc"): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_list_operations_field_headers(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
+ request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = operations_pb2.ListOperationsResponse() + + client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_operations_field_headers_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_list_operations_from_dict(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + + response = client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_list_operations_from_dict_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_list_locations(transport: str = "grpc"): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = locations_pb2.ListLocationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_locations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = locations_pb2.ListLocationsResponse() + response = client.list_locations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, locations_pb2.ListLocationsResponse) + + +@pytest.mark.asyncio +async def test_list_locations_async(transport: str = "grpc"): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = locations_pb2.ListLocationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_locations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + locations_pb2.ListLocationsResponse() + ) + response = await client.list_locations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, locations_pb2.ListLocationsResponse) + + +def test_list_locations_field_headers(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = locations_pb2.ListLocationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_locations), "__call__") as call: + call.return_value = locations_pb2.ListLocationsResponse() + + client.list_locations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_locations_field_headers_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = locations_pb2.ListLocationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_locations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + locations_pb2.ListLocationsResponse() + ) + await client.list_locations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_list_locations_from_dict(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_locations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = locations_pb2.ListLocationsResponse() + + response = client.list_locations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_list_locations_from_dict_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.list_locations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + locations_pb2.ListLocationsResponse() + ) + response = await client.list_locations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_get_location(transport: str = "grpc"): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = locations_pb2.GetLocationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_location), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = locations_pb2.Location() + response = client.get_location(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, locations_pb2.Location) + + +@pytest.mark.asyncio +async def test_get_location_async(transport: str = "grpc_asyncio"): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = locations_pb2.GetLocationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_location), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + locations_pb2.Location() + ) + response = await client.get_location(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, locations_pb2.Location) + + +def test_get_location_field_headers(): + client = ModelGardenServiceClient(credentials=ga_credentials.AnonymousCredentials()) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = locations_pb2.GetLocationRequest() + request.name = "locations/abc" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_location), "__call__") as call: + call.return_value = locations_pb2.Location() + + client.get_location(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations/abc", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_location_field_headers_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials() + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = locations_pb2.GetLocationRequest() + request.name = "locations/abc" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.get_location), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + locations_pb2.Location() + ) + await client.get_location(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations/abc", + ) in kw["metadata"] + + +def test_get_location_from_dict(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_locations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = locations_pb2.Location() + + response = client.get_location( + request={ + "name": "locations/abc", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_get_location_from_dict_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_locations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + locations_pb2.Location() + ) + response = await client.get_location( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_set_iam_policy(transport: str = "grpc"): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. 
+ request = iam_policy_pb2.SetIamPolicyRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.set_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = policy_pb2.Policy( + version=774, + etag=b"etag_blob", + ) + response = client.set_iam_policy(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, policy_pb2.Policy) + + assert response.version == 774 + + assert response.etag == b"etag_blob" + + +@pytest.mark.asyncio +async def test_set_iam_policy_async(transport: str = "grpc_asyncio"): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = iam_policy_pb2.SetIamPolicyRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.set_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + policy_pb2.Policy( + version=774, + etag=b"etag_blob", + ) + ) + response = await client.set_iam_policy(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, policy_pb2.Policy) + + assert response.version == 774 + + assert response.etag == b"etag_blob" + + +def test_set_iam_policy_field_headers(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = iam_policy_pb2.SetIamPolicyRequest() + request.resource = "resource/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.set_iam_policy), "__call__") as call: + call.return_value = policy_pb2.Policy() + + client.set_iam_policy(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "resource=resource/value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_set_iam_policy_field_headers_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = iam_policy_pb2.SetIamPolicyRequest() + request.resource = "resource/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.set_iam_policy), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(policy_pb2.Policy()) + + await client.set_iam_policy(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "resource=resource/value", + ) in kw["metadata"] + + +def test_set_iam_policy_from_dict(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.set_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = policy_pb2.Policy() + + response = client.set_iam_policy( + request={ + "resource": "resource_value", + "policy": policy_pb2.Policy(version=774), + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_set_iam_policy_from_dict_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.set_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(policy_pb2.Policy()) + + response = await client.set_iam_policy( + request={ + "resource": "resource_value", + "policy": policy_pb2.Policy(version=774), + } + ) + call.assert_called() + + +def test_get_iam_policy(transport: str = "grpc"): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = iam_policy_pb2.GetIamPolicyRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = policy_pb2.Policy( + version=774, + etag=b"etag_blob", + ) + + response = client.get_iam_policy(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, policy_pb2.Policy) + + assert response.version == 774 + + assert response.etag == b"etag_blob" + + +@pytest.mark.asyncio +async def test_get_iam_policy_async(transport: str = "grpc_asyncio"): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = iam_policy_pb2.GetIamPolicyRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + policy_pb2.Policy( + version=774, + etag=b"etag_blob", + ) + ) + + response = await client.get_iam_policy(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, policy_pb2.Policy) + + assert response.version == 774 + + assert response.etag == b"etag_blob" + + +def test_get_iam_policy_field_headers(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
+ request = iam_policy_pb2.GetIamPolicyRequest() + request.resource = "resource/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_iam_policy), "__call__") as call: + call.return_value = policy_pb2.Policy() + + client.get_iam_policy(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "resource=resource/value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_iam_policy_field_headers_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = iam_policy_pb2.GetIamPolicyRequest() + request.resource = "resource/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_iam_policy), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(policy_pb2.Policy()) + + await client.get_iam_policy(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "resource=resource/value", + ) in kw["metadata"] + + +def test_get_iam_policy_from_dict(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = policy_pb2.Policy() + + response = client.get_iam_policy( + request={ + "resource": "resource_value", + "options": options_pb2.GetPolicyOptions(requested_policy_version=2598), + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_get_iam_policy_from_dict_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(policy_pb2.Policy()) + + response = await client.get_iam_policy( + request={ + "resource": "resource_value", + "options": options_pb2.GetPolicyOptions(requested_policy_version=2598), + } + ) + call.assert_called() + + +def test_test_iam_permissions(transport: str = "grpc"): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = iam_policy_pb2.TestIamPermissionsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.test_iam_permissions), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = iam_policy_pb2.TestIamPermissionsResponse( + permissions=["permissions_value"], + ) + + response = client.test_iam_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, iam_policy_pb2.TestIamPermissionsResponse) + + assert response.permissions == ["permissions_value"] + + +@pytest.mark.asyncio +async def test_test_iam_permissions_async(transport: str = "grpc_asyncio"): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = iam_policy_pb2.TestIamPermissionsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.test_iam_permissions), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + iam_policy_pb2.TestIamPermissionsResponse( + permissions=["permissions_value"], + ) + ) + + response = await client.test_iam_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, iam_policy_pb2.TestIamPermissionsResponse) + + assert response.permissions == ["permissions_value"] + + +def test_test_iam_permissions_field_headers(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = iam_policy_pb2.TestIamPermissionsRequest() + request.resource = "resource/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.test_iam_permissions), "__call__" + ) as call: + call.return_value = iam_policy_pb2.TestIamPermissionsResponse() + + client.test_iam_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "resource=resource/value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_test_iam_permissions_field_headers_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = iam_policy_pb2.TestIamPermissionsRequest() + request.resource = "resource/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.test_iam_permissions), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + iam_policy_pb2.TestIamPermissionsResponse() + ) + + await client.test_iam_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "resource=resource/value", + ) in kw["metadata"] + + +def test_test_iam_permissions_from_dict(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.test_iam_permissions), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
+ call.return_value = iam_policy_pb2.TestIamPermissionsResponse() + + response = client.test_iam_permissions( + request={ + "resource": "resource_value", + "permissions": ["permissions_value"], + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_test_iam_permissions_from_dict_async(): + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.test_iam_permissions), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + iam_policy_pb2.TestIamPermissionsResponse() + ) + + response = await client.test_iam_permissions( + request={ + "resource": "resource_value", + "permissions": ["permissions_value"], + } + ) + call.assert_called() + + +def test_transport_close(): + transports = { + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "grpc", + ] + for transport in transports: + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. 
+ with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (ModelGardenServiceClient, transports.ModelGardenServiceGrpcTransport), + ( + ModelGardenServiceAsyncClient, + transports.ModelGardenServiceGrpcAsyncIOTransport, + ), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_schedule_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_schedule_service.py index 860ae458b6..e761e91c0a 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_schedule_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_schedule_service.py @@ -2277,6 +2277,7 @@ def test_resume_schedule_flattened(): # using the keyword arguments to the method. 
client.resume_schedule( name="name_value", + catch_up=True, ) # Establish that the underlying call was made with the expected @@ -2286,6 +2287,9 @@ def test_resume_schedule_flattened(): arg = args[0].name mock_val = "name_value" assert arg == mock_val + arg = args[0].catch_up + mock_val = True + assert arg == mock_val def test_resume_schedule_flattened_error(): @@ -2299,6 +2303,7 @@ def test_resume_schedule_flattened_error(): client.resume_schedule( schedule_service.ResumeScheduleRequest(), name="name_value", + catch_up=True, ) @@ -2318,6 +2323,7 @@ async def test_resume_schedule_flattened_async(): # using the keyword arguments to the method. response = await client.resume_schedule( name="name_value", + catch_up=True, ) # Establish that the underlying call was made with the expected @@ -2327,6 +2333,9 @@ async def test_resume_schedule_flattened_async(): arg = args[0].name mock_val = "name_value" assert arg == mock_val + arg = args[0].catch_up + mock_val = True + assert arg == mock_val @pytest.mark.asyncio @@ -2341,6 +2350,7 @@ async def test_resume_schedule_flattened_error_async(): await client.resume_schedule( schedule_service.ResumeScheduleRequest(), name="name_value", + catch_up=True, ) diff --git a/vertexai/__init__.py b/vertexai/__init__.py new file mode 100644 index 0000000000..5eff2b4391 --- /dev/null +++ b/vertexai/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +"""The vertexai module.""" + +from google.cloud.aiplatform import init + +__all__ = [ + "init", +] diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py new file mode 100644 index 0000000000..0f1c3569ab --- /dev/null +++ b/vertexai/language_models/_language_models.py @@ -0,0 +1,862 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Classes for working with language models.""" + +import dataclasses +import tempfile +from typing import Any, List, Optional, Sequence, Type, Union + +from google.cloud import storage + +from google.cloud import aiplatform +from google.cloud.aiplatform import base +from google.cloud.aiplatform import initializer as aiplatform_initializer +from google.cloud.aiplatform import models as aiplatform_models +from google.cloud.aiplatform import utils as aiplatform_utils +from google.cloud.aiplatform.preview import _publisher_model +from google.cloud.aiplatform.utils import gcs_utils + +try: + import pandas +except ImportError: + pandas = None + + +_LOGGER = base.Logger(__name__) + +_TEXT_GENERATION_TUNING_PIPELINE_URI = "https://us-kfp.pkg.dev/vertex-ai/large-language-model-pipelines/tune-large-model/preview" + +# Endpoint label/metadata key to preserve the base model ID information +_TUNING_BASE_MODEL_ID_LABEL_KEY = "google-vertex-llm-tuning-base-model-id" + +_LLM_TEXT_GENERATION_INSTANCE_SCHEMA_URI = ( + 
"gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml" +) +_LLM_CHAT_GENERATION_INSTANCE_SCHEMA_URI = ( + "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml" +) +_LLM_TEXT_EMBEDDING_INSTANCE_SCHEMA_URI = ( + "gs://google-cloud-aiplatform/schema/predict/instance/text_embedding_1.0.0.yaml" +) + + +@dataclasses.dataclass +class _ModelInfo: + endpoint_name: str + interface_class: Type["_LanguageModel"] + tuning_pipeline_uri: Optional[str] = None + tuning_model_id: Optional[str] = None + + +def _get_model_info(model_id: str) -> _ModelInfo: + """Gets the model information by model ID.""" + + # The default publisher is Google + if "/" not in model_id: + model_id = "publishers/google/models/" + model_id + + publisher_model_res = ( + _publisher_model._PublisherModel( # pylint: disable=protected-access + resource_name=model_id + )._gca_resource + ) + + if not publisher_model_res.name.startswith("publishers/google/models/"): + raise ValueError( + f"Only Google models are currently supported. {publisher_model_res.name}" + ) + short_model_id = publisher_model_res.name.rsplit("/", 1)[-1] + + # == "projects/{project}/locations/{location}/publishers/google/models/text-bison@001" + publisher_model_template = publisher_model_res.publisher_model_template.replace( + "{user-project}", "{project}" + ) + if not publisher_model_template: + raise RuntimeError( + f"The model does not have an associated Publisher Model. 
{publisher_model_res.name}" + ) + + endpoint_name = publisher_model_template.format( + project=aiplatform_initializer.global_config.project, + location=aiplatform_initializer.global_config.location, + ) + if short_model_id == "text-bison": + tuning_pipeline_uri = _TEXT_GENERATION_TUNING_PIPELINE_URI + tuning_model_id = short_model_id + "-" + publisher_model_res.version_id + else: + tuning_pipeline_uri = None + tuning_model_id = None + + interface_class_map = { + _LLM_TEXT_GENERATION_INSTANCE_SCHEMA_URI: TextGenerationModel, + _LLM_CHAT_GENERATION_INSTANCE_SCHEMA_URI: ChatModel, + _LLM_TEXT_EMBEDDING_INSTANCE_SCHEMA_URI: TextEmbeddingModel, + } + + interface_class = interface_class_map.get( + publisher_model_res.predict_schemata.instance_schema_uri + ) + + if not interface_class: + raise ValueError(f"Unknown model {publisher_model_res.name}") + + return _ModelInfo( + endpoint_name=endpoint_name, + interface_class=interface_class, + tuning_pipeline_uri=tuning_pipeline_uri, + tuning_model_id=tuning_model_id, + ) + + +def _get_model_id_from_tuning_model_id(tuning_model_id: str) -> str: + """Gets the base model ID for the model ID labels used the tuned models. + + Args: + tuning_model_id: The model ID used in tuning + + Returns: + The publisher model ID + + Raises: + ValueError: If tuning model ID is unsupported + """ + if tuning_model_id.startswith("text-bison-"): + return tuning_model_id.replace( + "text-bison-", "publishers/google/models/text-bison@" + ) + raise ValueError(f"Unsupported tuning model ID {tuning_model_id}") + + +class _LanguageModel: + """_LanguageModel is a base class for all language models.""" + + def __init__(self, model_id: str, endpoint_name: Optional[str] = None): + """Creates a LanguageModel. + + This constructor should not be called directly. + Use `LanguageModel.from_pretrained(model_name=...)` instead. + + Args: + model_id: Identifier of a Vertex LLM. 
+                Example: "text-bison@001"
+            endpoint_name: Vertex Endpoint resource name for the model
+        """
+        self._model_id = model_id
+        self._endpoint_name = endpoint_name
+        # TODO(b/280879204)
+        # A workaround for not being able to directly instantiate the
+        # high-level Endpoint with the PublisherModel resource name.
+        self._endpoint = aiplatform.Endpoint._construct_sdk_resource_from_gapic(
+            aiplatform_models.gca_endpoint_compat.Endpoint(name=endpoint_name)
+        )
+
+    @classmethod
+    def from_pretrained(cls, model_name: str) -> "_LanguageModel":
+        """Loads a LanguageModel.
+
+        Args:
+            model_name: Name of the model.
+
+        Returns:
+            An instance of a class derived from `_LanguageModel`.
+
+        Raises:
+            ValueError: If model_name is unknown.
+            ValueError: If model does not support this class.
+        """
+        model_info = _get_model_info(model_id=model_name)
+
+        if not issubclass(model_info.interface_class, cls):
+            raise ValueError(
+                f"{model_name} is of type {model_info.interface_class.__name__} not of type {cls.__name__}"
+            )
+
+        return model_info.interface_class(
+            model_id=model_name,
+            endpoint_name=model_info.endpoint_name,
+        )
+
+    def list_tuned_model_names(self) -> Sequence[str]:
+        """Lists the names of tuned models.
+
+        Returns:
+            A list of tuned models that can be used with the `get_tuned_model` method.
+ """ + model_info = _get_model_info(model_id=self._model_id) + return _list_tuned_model_names(model_id=model_info.tuning_model_id) + + @staticmethod + def get_tuned_model(tuned_model_name: str) -> "_LanguageModel": + """Loads the specified tuned language model.""" + + tuned_vertex_model = aiplatform.Model(tuned_model_name) + tuned_model_deployments = tuned_vertex_model.gca_resource.deployed_models + if len(tuned_model_deployments) == 0: + # Deploying the model + endpoint_name = tuned_vertex_model.deploy().resource_name + else: + endpoint_name = tuned_model_deployments[0].endpoint + + tuning_model_id = tuned_vertex_model.labels[_TUNING_BASE_MODEL_ID_LABEL_KEY] + base_model_id = _get_model_id_from_tuning_model_id(tuning_model_id) + model_info = _get_model_info(model_id=base_model_id) + model = model_info.interface_class( + model_id=base_model_id, + endpoint_name=endpoint_name, + ) + return model + + def tune_model( + self, + training_data: Union[str, "pandas.core.frame.DataFrame"], + *, + train_steps: int = 1000, + tuning_job_location: Optional[str] = None, + tuned_model_location: Optional[str] = None, + model_display_name: Optional[str] = None, + ): + """Tunes a model based on training data. + + This method launches a model tuning job that can take some time. + + Args: + training_data: A Pandas DataFrame of a URI pointing to data in JSON lines format. + The dataset must have the "input_text" and "output_text" columns. + train_steps: Number of training steps to perform. + tuning_job_location: GCP location where the tuning job should be run. Only "europe-west4" is supported for now. + tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now. + model_display_name: Custom display name for the tuned model. + + Returns: + A `LanguageModelTuningJob` object that represents the tuning job. + Calling `job.result()` blocks until the tuning is complete and returns a `LanguageModel` object. 
+ + Raises: + ValueError: If the "tuning_job_location" value is not supported + ValueError: If the "tuned_model_location" value is not supported + RuntimeError: If the model does not support tuning + """ + if tuning_job_location != _TUNING_LOCATION: + raise ValueError( + f'Tuning is only supported in the following locations: tuning_job_location="{_TUNING_LOCATION}"' + ) + if tuned_model_location != _TUNED_MODEL_LOCATION: + raise ValueError( + f'Model deployment is only supported in the following locations: tuned_model_location="{_TUNED_MODEL_LOCATION}"' + ) + model_info = _get_model_info(model_id=self._model_id) + if not model_info.tuning_pipeline_uri: + raise RuntimeError(f"The {self._model_id} model does not support tuning") + pipeline_job = _launch_tuning_job( + training_data=training_data, + train_steps=train_steps, + model_id=model_info.tuning_model_id, + tuning_pipeline_uri=model_info.tuning_pipeline_uri, + model_display_name=model_display_name, + ) + + job = _LanguageModelTuningJob( + base_model=self, + job=pipeline_job, + ) + self._job = job + tuned_model = job.result() + # The UXR study attendees preferred to tune model in place + self._endpoint = tuned_model._endpoint + + +@dataclasses.dataclass +class TextGenerationResponse: + """TextGenerationResponse represents a response of a language model.""" + + text: str + _prediction_response: Any + + def __repr__(self): + return self.text + + +class TextGenerationModel(_LanguageModel): + """TextGenerationModel represents a general language model. 
+ + Examples: + + # Getting answers: + model = TextGenerationModel.from_pretrained("text-bison@001") + model.predict("What is life?") + """ + + _DEFAULT_TEMPERATURE = 0.0 + _DEFAULT_MAX_OUTPUT_TOKENS = 128 + _DEFAULT_TOP_P = 0.95 + _DEFAULT_TOP_K = 40 + + def predict( + self, + prompt: str, + *, + max_output_tokens: int = _DEFAULT_MAX_OUTPUT_TOKENS, + temperature: float = _DEFAULT_TEMPERATURE, + top_k: int = _DEFAULT_TOP_K, + top_p: float = _DEFAULT_TOP_P, + ) -> "TextGenerationResponse": + """Gets model response for a single prompt. + + Args: + prompt: Question to ask the model. + max_output_tokens: Max length of the output text in tokens. + temperature: Controls the randomness of predictions. Range: [0, 1]. + top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. + + Returns: + A `TextGenerationResponse` object that contains the text produced by the model. + """ + + return self._batch_predict( + prompts=[prompt], + max_output_tokens=max_output_tokens, + temperature=temperature, + top_k=top_k, + top_p=top_p, + )[0] + + def _batch_predict( + self, + prompts: List[str], + max_output_tokens: int = _DEFAULT_MAX_OUTPUT_TOKENS, + temperature: float = _DEFAULT_TEMPERATURE, + top_k: int = _DEFAULT_TOP_K, + top_p: float = _DEFAULT_TOP_P, + ) -> List["TextGenerationResponse"]: + """Gets model response for a single prompt. + + Args: + prompts: Questions to ask the model. + max_output_tokens: Max length of the output text in tokens. + temperature: Controls the randomness of predictions. Range: [0, 1]. + top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. 
+ + Returns: + A list of `TextGenerationResponse` objects that contain the texts produced by the model. + """ + instances = [{"content": str(prompt)} for prompt in prompts] + prediction_parameters = { + "temperature": temperature, + "maxDecodeSteps": max_output_tokens, + "topP": top_p, + "topK": top_k, + } + + prediction_response = self._endpoint.predict( + instances=instances, + parameters=prediction_parameters, + ) + + return [ + TextGenerationResponse( + text=prediction["content"], + _prediction_response=prediction_response, + ) + for prediction in prediction_response.predictions + ] + + +class _ChatModel(TextGenerationModel): + """ChatModel represents a language model that is capable of chat. + + Examples: + + # Getting answers: + model = ChatModel.from_pretrained("chat-bison@001") + model.predict("What is life?") + + # Chat: + chat = model.start_chat() + + chat.send_message("Do you know any cool events this weekend?") + """ + + def start_chat( + self, + max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS, + temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE, + top_k: int = TextGenerationModel._DEFAULT_TOP_K, + top_p: float = TextGenerationModel._DEFAULT_TOP_P, + ) -> "_ChatSession": + """Starts a chat session with the model. + + Args: + max_output_tokens: Max length of the output text in tokens. + temperature: Controls the randomness of predictions. Range: [0, 1]. + top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. + + Returns: + A `ChatSession` object. + """ + return _ChatSession( + model=self, + max_output_tokens=max_output_tokens, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + + +class _ChatSession: + """ChatSession represents a chat session with a language model. 
+
+    Within a chat session, the model keeps context and remembers the previous conversation.
+    """
+
+    def __init__(
+        self,
+        model: _ChatModel,
+        max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
+        temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
+        top_k: int = TextGenerationModel._DEFAULT_TOP_K,
+        top_p: float = TextGenerationModel._DEFAULT_TOP_P,
+    ):
+        self._model = model
+        self._history = []
+        self._history_text = ""
+        self._max_output_tokens = max_output_tokens
+        self._temperature = temperature
+        self._top_k = top_k
+        self._top_p = top_p
+
+    def send_message(
+        self,
+        message: str,
+        *,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+    ) -> "TextGenerationResponse":
+        """Sends message to the language model and gets a response.
+
+        Args:
+            message: Message to send to the model
+            max_output_tokens: Max length of the output text in tokens. Defaults to the value set in `start_chat`.
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Defaults to the value set in `start_chat`.
+            top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to the value set in `start_chat`.
+            top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Defaults to the value set in `start_chat`.
+
+        Returns:
+            A `TextGenerationResponse` object that contains the text produced by the model.
+ """ + new_history_text = "" + if self._history_text: + new_history_text = self._history_text.rstrip("\n") + "\n\n" + new_history_text += message.rstrip("\n") + "\n" + + response_obj = self._model.predict( + prompt=new_history_text, + max_output_tokens=max_output_tokens or self._max_output_tokens, + temperature=temperature or self._temperature, + top_k=top_k or self._top_k, + top_p=top_p or self._top_p, + ) + response_text = response_obj.text + + self._history.append((message, response_text)) + new_history_text += response_text.rstrip("\n") + "\n" + self._history_text = new_history_text + return response_obj + + +class TextEmbeddingModel(_LanguageModel): + """TextEmbeddingModel converts text into a vector of floating-point numbers. + + Examples: + + # Getting embedding: + model = TextEmbeddingModel.from_pretrained("embedding-gecko@001") + embeddings = model.get_embeddings(["What is life?"]) + for embedding in embeddings: + vector = embedding.values + print(len(vector)) + """ + + def get_embeddings(self, texts: List[str]) -> List["TextEmbedding"]: + instances = [{"content": str(text)} for text in texts] + + prediction_response = self._endpoint.predict( + instances=instances, + ) + + return [ + TextEmbedding( + values=prediction["embeddings"]["values"], + _prediction_response=prediction_response, + ) + for prediction in prediction_response.predictions + ] + + +class TextEmbedding: + """Contains text embedding vector.""" + + def __init__( + self, + values: List[float], + _prediction_response: Any = None, + ): + self.values = values + self._prediction_response = _prediction_response + + +@dataclasses.dataclass +class InputOutputTextPair: + """InputOutputTextPair represents a pair of input and output texts.""" + + input_text: str + output_text: str + + +class ChatModel(_LanguageModel): + """ChatModel represents a language model that is capable of chat. 
+ + Examples: + + chat_model = ChatModel.from_pretrained("chat-bison@001") + + chat = chat_model.start_chat( + context="My name is Ned. You are my personal assistant. My favorite movies are Lord of the Rings and Hobbit.", + examples=[ + InputOutputTextPair( + input_text="Who do you work for?", + output_text="I work for Ned.", + ), + InputOutputTextPair( + input_text="What do I like?", + output_text="Ned likes watching movies.", + ), + ], + temperature=0.3, + ) + + chat.send_message("Do you know any cool events this weekend?") + """ + + def start_chat( + self, + *, + context: Optional[str] = None, + examples: Optional[List[InputOutputTextPair]] = None, + max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS, + temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE, + top_k: int = TextGenerationModel._DEFAULT_TOP_K, + top_p: float = TextGenerationModel._DEFAULT_TOP_P, + ) -> "ChatSession": + """Starts a chat session with the model. + + Args: + context: Context shapes how the model responds throughout the conversation. + For example, you can use context to specify words the model can or cannot use, topics to focus on or avoid, or the response format or style + examples: List of structured messages to the model to learn how to respond to the conversation. + A list of `InputOutputTextPair` objects. + max_output_tokens: Max length of the output text in tokens. + temperature: Controls the randomness of predictions. Range: [0, 1]. + top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40] + top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. + + Returns: + A `ChatSession` object. 
+ """ + return ChatSession( + model=self, + context=context, + examples=examples, + max_output_tokens=max_output_tokens, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + + +class ChatSession: + """ChatSession represents a chat session with a language model. + + Within a chat session, the model keeps context and remembers the previous conversation. + """ + + def __init__( + self, + model: ChatModel, + context: Optional[str] = None, + examples: Optional[List[InputOutputTextPair]] = None, + max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS, + temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE, + top_k: int = TextGenerationModel._DEFAULT_TOP_K, + top_p: float = TextGenerationModel._DEFAULT_TOP_P, + ): + self._model = model + self._context = context + self._examples = examples + self._history = [] + self._max_output_tokens = max_output_tokens + self._temperature = temperature + self._top_k = top_k + self._top_p = top_p + + def send_message( + self, + message: str, + *, + max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS, + temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE, + top_k: int = TextGenerationModel._DEFAULT_TOP_K, + top_p: float = TextGenerationModel._DEFAULT_TOP_P, + ) -> "TextGenerationResponse": + """Sends message to the language model and gets a response. + + Args: + message: Message to send to the model + max_output_tokens: Max length of the output text in tokens. + temperature: Controls the randomness of predictions. Range: [0, 1]. + top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. + + Returns: + A `TextGenerationResponse` object that contains the text produced by the model. 
+ """ + prediction_parameters = { + "temperature": temperature, + "maxDecodeSteps": max_output_tokens, + "topP": top_p, + "topK": top_k, + } + messages = [] + for input_text, output_text in self._history: + messages.append( + { + "author": "user", + "content": input_text, + } + ) + messages.append( + { + "author": "bot", + "content": output_text, + } + ) + + messages.append( + { + "author": "user", + "content": message, + } + ) + + prediction_instance = {"messages": messages} + if self._context: + prediction_instance["context"] = self._context + if self._examples: + prediction_instance["examples"] = [ + { + "input": {"content": example.input_text}, + "output": {"content": example.output_text}, + } + for example in self._examples + ] + + prediction_response = self._model._endpoint.predict( + instances=[prediction_instance], + parameters=prediction_parameters, + ) + + response_obj = TextGenerationResponse( + text=prediction_response.predictions[0]["candidates"][0]["content"], + _prediction_response=prediction_response, + ) + response_text = response_obj.text + + self._history.append((message, response_text)) + return response_obj + + +###### Model tuning +# Currently, tuning can only work in this location +_TUNING_LOCATION = "europe-west4" +# Currently, deployment can only work in this location +_TUNED_MODEL_LOCATION = "us-central1" + + +class _LanguageModelTuningJob: + """LanguageModelTuningJob represents a fine-tuning job.""" + + def __init__( + self, + base_model: _LanguageModel, + job: aiplatform.PipelineJob, + ): + self._base_model = base_model + self._job = job + self._model: Optional[_LanguageModel] = None + + def result(self) -> "_LanguageModel": + """Blocks until the tuning is complete and returns a `LanguageModel` object.""" + if self._model: + return self._model + self._job.wait() + upload_model_tasks = [ + task_info + for task_info in self._job.gca_resource.job_detail.task_details + if task_info.task_name == "upload-llm-model" + ] + if 
len(upload_model_tasks) != 1: + raise RuntimeError( + f"Failed to get the model name from the tuning pipeline: {self._job.name}" + ) + upload_model_task = upload_model_tasks[0] + + # Trying to get model name from output parameter + vertex_model_name = upload_model_task.execution.metadata[ + "output:model_resource_name" + ].strip() + _LOGGER.info(f"Tuning has completed. Created Vertex Model: {vertex_model_name}") + self._model = type(self._base_model).get_tuned_model( + tuned_model_name=vertex_model_name + ) + return self._model + + @property + def status(self): + """Job status""" + return self._job.state + + def cancel(self): + self._job.cancel() + + +def _get_tuned_models_dir_uri(model_id: str) -> str: + staging_gcs_bucket = ( + gcs_utils.create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist() + ) + return ( + staging_gcs_bucket.replace("/output_artifacts/", "/tuned_language_models/") + + model_id + ) + + +def _list_tuned_model_names(model_id: str) -> List[str]: + tuned_models = aiplatform.Model.list( + filter=f'labels.{_TUNING_BASE_MODEL_ID_LABEL_KEY}="{model_id}"', + # TODO(b/275444096): Remove the explicit location once models are deployed to the user's selected location + location=_TUNED_MODEL_LOCATION, + ) + model_names = [model.resource_name for model in tuned_models] + return model_names + + +def _generate_tuned_model_dir_uri(model_id: str) -> str: + tuned_model_id = "tuned_model_" + aiplatform_utils.timestamped_unique_name() + tuned_models_dir_uri = _get_tuned_models_dir_uri(model_id=model_id) + tuned_model_dir_uri = _uri_join(tuned_models_dir_uri, tuned_model_id) + return tuned_model_dir_uri + + +def _launch_tuning_job( + training_data: Union[str, "pandas.core.frame.DataFrame"], + model_id: str, + tuning_pipeline_uri: str, + train_steps: Optional[int] = None, + model_display_name: Optional[str] = None, +) -> aiplatform.PipelineJob: + output_dir_uri = _generate_tuned_model_dir_uri(model_id=model_id) + if isinstance(training_data, str): + 
dataset_uri = training_data + elif pandas and isinstance(training_data, pandas.DataFrame): + dataset_uri = _uri_join(output_dir_uri, "training_data.jsonl") + + with tempfile.NamedTemporaryFile() as temp_file: + dataset_path = temp_file.name + df = training_data + df = df[["input_text", "output_text"]] + df.to_json(path_or_buf=dataset_path, orient="records", lines=True) + storage_client = storage.Client( + credentials=aiplatform_initializer.global_config.credentials + ) + storage.Blob.from_string( + uri=dataset_uri, client=storage_client + ).upload_from_filename(filename=dataset_path) + else: + raise TypeError(f"Unsupported training_data type: {type(training_data)}") + + job = _launch_tuning_job_on_jsonl_data( + model_id=model_id, + dataset_name_or_uri=dataset_uri, + train_steps=train_steps, + tuning_pipeline_uri=tuning_pipeline_uri, + model_display_name=model_display_name, + ) + return job + + +def _launch_tuning_job_on_jsonl_data( + model_id: str, + dataset_name_or_uri: str, + tuning_pipeline_uri: str, + train_steps: Optional[int] = None, + model_display_name: Optional[str] = None, +) -> aiplatform.PipelineJob: + if not model_display_name: + # Creating a human-readable model display name + name = f"{model_id} tuned for {train_steps} steps on " + # Truncating the start of the dataset URI to keep total length <= 128. + max_display_name_length = 128 + if len(dataset_name_or_uri + name) <= max_display_name_length: + name += dataset_name_or_uri + else: + name += "..." 
+ remaining_length = max_display_name_length - len(name) + name += dataset_name_or_uri[-remaining_length:] + model_display_name = name[:max_display_name_length] + + pipeline_arguments = { + "train_steps": train_steps, + "project": aiplatform_initializer.global_config.project, + # TODO(b/275444096): Remove the explicit location once tuning can happen in all regions + # "location": aiplatform_initializer.global_config.location, + "location": _TUNED_MODEL_LOCATION, + "large_model_reference": model_id, + "model_display_name": model_display_name, + } + + if dataset_name_or_uri.startswith("projects/"): + pipeline_arguments["dataset_name"] = dataset_name_or_uri + if dataset_name_or_uri.startswith("gs://"): + pipeline_arguments["dataset_uri"] = dataset_name_or_uri + job = aiplatform.PipelineJob( + template_path=tuning_pipeline_uri, + display_name=None, + parameter_values=pipeline_arguments, + # TODO(b/275444101): Remove the explicit location once model can be deployed in all regions + location=_TUNING_LOCATION, + ) + job.submit() + return job + + +def _uri_join(uri: str, path_fragment: str) -> str: + """Appends path fragment to URI. + + urllib.parse.urljoin only works on URLs, not URIs + """ + + return uri.rstrip("/") + "/" + path_fragment.lstrip("/") diff --git a/vertexai/preview/language_models.py b/vertexai/preview/language_models.py new file mode 100644 index 0000000000..285495e7c5 --- /dev/null +++ b/vertexai/preview/language_models.py @@ -0,0 +1,35 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Classes for working with language models.""" + +from vertexai.language_models._language_models import ( + ChatModel, + ChatSession, + InputOutputTextPair, + TextEmbedding, + TextEmbeddingModel, + TextGenerationModel, + TextGenerationResponse, +) + +__all__ = [ + "ChatModel", + "ChatSession", + "InputOutputTextPair", + "TextEmbedding", + "TextEmbeddingModel", + "TextGenerationModel", + "TextGenerationResponse", +]