From 5eb885ee7e01eece15679ce400f222930da1ac16 Mon Sep 17 00:00:00 2001 From: Yeesian Ng Date: Thu, 2 May 2024 18:02:44 -0700 Subject: [PATCH 01/30] feat: Add support for BaseModels in LangChain templates PiperOrigin-RevId: 630232213 --- .release-please-manifest.json | 2 +- CHANGELOG.md | 24 -- google/cloud/aiplatform/gapic_version.py | 2 +- .../schema/predict/instance/gapic_version.py | 2 +- .../predict/instance_v1/gapic_version.py | 2 +- .../v1/schema/predict/params/gapic_version.py | 2 +- .../schema/predict/params_v1/gapic_version.py | 2 +- .../predict/prediction/gapic_version.py | 2 +- .../predict/prediction_v1/gapic_version.py | 2 +- .../trainingjob/definition/gapic_version.py | 2 +- .../definition_v1/gapic_version.py | 2 +- .../schema/predict/instance/gapic_version.py | 2 +- .../predict/instance_v1beta1/gapic_version.py | 2 +- .../schema/predict/params/gapic_version.py | 2 +- .../predict/params_v1beta1/gapic_version.py | 2 +- .../predict/prediction/gapic_version.py | 2 +- .../prediction_v1beta1/gapic_version.py | 2 +- .../trainingjob/definition/gapic_version.py | 2 +- .../definition_v1beta1/gapic_version.py | 2 +- google/cloud/aiplatform/version.py | 2 +- google/cloud/aiplatform_v1/gapic_version.py | 2 +- .../cloud/aiplatform_v1beta1/gapic_version.py | 2 +- pypi/_vertex_ai_placeholder/version.py | 2 +- ...t_metadata_google.cloud.aiplatform.v1.json | 2 +- ...adata_google.cloud.aiplatform.v1beta1.json | 2 +- setup.py | 4 +- ...st_reasoning_engine_templates_langchain.py | 101 ++---- .../reasoning_engines/templates/langchain.py | 294 ++++++++---------- 28 files changed, 183 insertions(+), 288 deletions(-) diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 2a7042a5dd..94124d32ea 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "1.50.0" + ".": "1.49.0" } diff --git a/CHANGELOG.md b/CHANGELOG.md index 683a1d2ec5..53e1884e47 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,29 +1,5 
@@ # Changelog -## [1.50.0](https://github.com/googleapis/python-aiplatform/compare/v1.49.0...v1.50.0) (2024-05-02) - - -### Features - -* Add `Candidate.grounding_metadata` property ([b22a8b8](https://github.com/googleapis/python-aiplatform/commit/b22a8b847e3b299b828e37405e3678093486de28)) -* Add option to not include time_series_metrics in get_experiment_df call. This will improve execution time for Experiments with large number of runs. ([78a95c5](https://github.com/googleapis/python-aiplatform/commit/78a95c52d0e7bd9ec5b656ce67044b2f01677156)) -* Add tune_model and deploy_tuned_model for TextEmbeddingModel. ([42f5d6f](https://github.com/googleapis/python-aiplatform/commit/42f5d6f7cd13d51c4a73113c59e8b3c728cfc08b)) -* Automatically populate parents for full resource name in Vertex RAG SDK ([26657ff](https://github.com/googleapis/python-aiplatform/commit/26657ffd25ecb91882ca764e513c2e952833257f)) -* Deploy a tuned text embedding model -- it doesn't matter, if it's tuned using Node.js, or curl. ([8ca9cdf](https://github.com/googleapis/python-aiplatform/commit/8ca9cdf3576e3ce3b373ace4cd6ab0e9c54aa9f2)) -* Make get_embeddings work both for foundational & tuned models. ([b8b589c](https://github.com/googleapis/python-aiplatform/commit/b8b589ce9fff29d1721450d32b4a84a7f69413c3)) -* Python SDK for Vertex Model Monitoring V2. 
([021d59f](https://github.com/googleapis/python-aiplatform/commit/021d59f1487e4e16c847d4135899d6845c0210aa)) -* Support public endpoint for Ray Client ([57a5f78](https://github.com/googleapis/python-aiplatform/commit/57a5f7815ffb8523e91d900da4ff7cfd0c344fe4)) - - -### Bug Fixes - -* Add deprecation warnings when using Ray v2.4 ([3a36784](https://github.com/googleapis/python-aiplatform/commit/3a367843840513e3257610c8ab38e9f79d3bcea0)) -* Append allowed_plugins in tb-gcp-uploader to default allowed plugins ([aab9c3e](https://github.com/googleapis/python-aiplatform/commit/aab9c3e41b92a1d60090e3d1d594390a5e9f3ff6)) -* LLM - Added missing parameters to the no-op `_TunableTextEmbeddingModelMixin.get_tuned_model` method ([eb05ac4](https://github.com/googleapis/python-aiplatform/commit/eb05ac421f186441a92c6e3b6a010d74caf14782)) -* LVM - Fixed the typo in the VisionModel aspect ratio type annotation ([2d19137](https://github.com/googleapis/python-aiplatform/commit/2d1913773cf9f4a4f8a2c8c8f45680c3ea97f68e)) -* Move torch import ([e6d34df](https://github.com/googleapis/python-aiplatform/commit/e6d34df7da7508c655eb17ee694e1ab2160fc8aa)) -* Ray - Fixed exception when using Ray 2.4 ([2661f52](https://github.com/googleapis/python-aiplatform/commit/2661f52fd08169e5d29b58f2afce9702b30101ae)) - ## [1.49.0](https://github.com/googleapis/python-aiplatform/compare/v1.48.0...v1.49.0) (2024-04-27) diff --git a/google/cloud/aiplatform/gapic_version.py b/google/cloud/aiplatform/gapic_version.py index 39995f175a..41b5da9439 100644 --- a/google/cloud/aiplatform/gapic_version.py +++ b/google/cloud/aiplatform/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.49.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py index 39995f175a..41b5da9439 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.49.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py index 39995f175a..41b5da9439 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.49.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py index 39995f175a..41b5da9439 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.49.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py index 39995f175a..41b5da9439 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.49.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py index 39995f175a..41b5da9439 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.49.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py index 39995f175a..41b5da9439 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.49.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py index 39995f175a..41b5da9439 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.49.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py index 39995f175a..41b5da9439 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.49.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py index 39995f175a..41b5da9439 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.49.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py index 39995f175a..41b5da9439 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.49.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py index 39995f175a..41b5da9439 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.49.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py index 39995f175a..41b5da9439 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.49.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py index 39995f175a..41b5da9439 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.49.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py index 39995f175a..41b5da9439 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.49.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py index 39995f175a..41b5da9439 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.49.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py index 39995f175a..41b5da9439 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.49.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/version.py b/google/cloud/aiplatform/version.py index 1af552a3c1..5f3bd3ab54 100644 --- a/google/cloud/aiplatform/version.py +++ b/google/cloud/aiplatform/version.py @@ -15,4 +15,4 @@ # limitations under the License. # -__version__ = "1.50.0" +__version__ = "1.49.0" diff --git a/google/cloud/aiplatform_v1/gapic_version.py b/google/cloud/aiplatform_v1/gapic_version.py index 39995f175a..41b5da9439 100644 --- a/google/cloud/aiplatform_v1/gapic_version.py +++ b/google/cloud/aiplatform_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.49.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform_v1beta1/gapic_version.py b/google/cloud/aiplatform_v1beta1/gapic_version.py index 39995f175a..41b5da9439 100644 --- a/google/cloud/aiplatform_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.49.0" # {x-release-please-version} diff --git a/pypi/_vertex_ai_placeholder/version.py b/pypi/_vertex_ai_placeholder/version.py index de799ba97d..78cddd4821 100644 --- a/pypi/_vertex_ai_placeholder/version.py +++ b/pypi/_vertex_ai_placeholder/version.py @@ -15,4 +15,4 @@ # limitations under the License. # -__version__ = "1.50.0" +__version__ = "1.49.0" diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json index 39acefe786..5234f5f287 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.50.0" + "version": "1.49.0" }, "snippets": [ { diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json index 79deba42e4..27545fcbc8 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.50.0" + "version": "1.49.0" }, "snippets": [ { diff --git a/setup.py b/setup.py index a3a0778f00..ca04338e62 100644 --- a/setup.py +++ b/setup.py @@ -149,9 +149,9 @@ ] langchain_extra_require = [ - "langchain >= 0.1.13, < 0.2", + "langchain >= 0.1.16, < 0.2", "langchain-core < 0.2", - "langchain-google-vertexai < 0.2", + "langchain-google-vertexai < 2", ] langchain_testing_extra_require = list( diff --git a/tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py b/tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py index 21bebc158f..c4f7d7fdda 100644 
--- a/tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py +++ b/tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py @@ -13,7 +13,6 @@ # limitations under the License. # import importlib -import json from typing import Optional from unittest import mock @@ -25,11 +24,9 @@ import pytest -from langchain_core import agents -from langchain_core import messages -from langchain_core import outputs -from langchain_core import tools as lc_tools +from langchain_core import prompts from langchain.load import dump as langchain_load_dump +from langchain.agents.format_scratchpad import format_to_openai_function_messages from langchain.tools.base import StructuredTool @@ -93,6 +90,18 @@ def setup_method(self): project=_TEST_PROJECT, location=_TEST_LOCATION, ) + self.prompt = { + "input": lambda x: x["input"], + "agent_scratchpad": ( + lambda x: format_to_openai_function_messages(x["intermediate_steps"]) + ), + } | prompts.ChatPromptTemplate.from_messages( + [ + ("user", "{input}"), + prompts.MessagesPlaceholder(variable_name="agent_scratchpad"), + ] + ) + self.output_parser = mock.Mock() def teardown_method(self): initializer.global_pool.shutdown(wait=True) @@ -105,24 +114,33 @@ def test_initialization(self): assert agent._runnable is None def test_initialization_with_tools(self): + tools = [ + place_tool_query, + StructuredTool.from_function(place_photo_query), + ] agent = reasoning_engines.LangchainAgent( model=_TEST_MODEL, - tools=[ - place_tool_query, - StructuredTool.from_function(place_photo_query), - ], + tools=tools, ) - for tool in agent._tools: - assert isinstance(tool, lc_tools.BaseTool) + for tool, agent_tool in zip(tools, agent._tools): + assert isinstance(agent_tool, type(tool)) def test_set_up(self, vertexai_init_mock): - agent = reasoning_engines.LangchainAgent(model=_TEST_MODEL) + agent = reasoning_engines.LangchainAgent( + model=_TEST_MODEL, + prompt=self.prompt, + output_parser=self.output_parser, + ) assert 
agent._runnable is None agent.set_up() assert agent._runnable is not None def test_query(self, langchain_dump_mock): - agent = reasoning_engines.LangchainAgent(model=_TEST_MODEL) + agent = reasoning_engines.LangchainAgent( + model=_TEST_MODEL, + prompt=self.prompt, + output_parser=self.output_parser, + ) agent._runnable = mock.Mock() mocks = mock.Mock() mocks.attach_mock(mock=agent._runnable, attribute="invoke") @@ -132,63 +150,6 @@ def test_query(self, langchain_dump_mock): ) -class TestDefaultOutputParser: - def test_parse_result_function_call(self, vertexai_init_mock): - agent = reasoning_engines.LangchainAgent(model=_TEST_MODEL) - agent.set_up() - tool_input = { - "photo_reference": "abcd1234", - "maxwidth": _DEFAULT_PLACE_PHOTO_MAXWIDTH, - } - result = agent._output_parser.parse_result( - [ - outputs.ChatGeneration( - message=messages.AIMessage( - content="", - additional_kwargs={ - "function_call": { - "name": "place_tool_query", - "arguments": json.dumps(tool_input), - }, - }, - ) - ) - ] - ) - assert isinstance(result, agents.AgentActionMessageLog) - assert result.tool == "place_tool_query" - assert result.tool_input == tool_input - - def test_parse_result_not_function_call(self, vertexai_init_mock): - agent = reasoning_engines.LangchainAgent(model=_TEST_MODEL) - agent.set_up() - content = "test content" - result = agent._output_parser.parse_result( - [ - outputs.ChatGeneration( - message=messages.AIMessage(content=content), - ) - ] - ) - assert isinstance(result, agents.AgentFinish) - assert result.return_values == {"output": content} - assert result.log == content - - -class TestDefaultOutputParserErrors: - def test_parse_result_non_chat_generation_errors(self, vertexai_init_mock): - agent = reasoning_engines.LangchainAgent(model=_TEST_MODEL) - agent.set_up() - with pytest.raises(ValueError, match=r"only works on ChatGeneration"): - agent._output_parser.parse_result(["text"]) - - def test_parse_text_errors(self, vertexai_init_mock): - agent = 
reasoning_engines.LangchainAgent(model=_TEST_MODEL) - agent.set_up() - with pytest.raises(ValueError, match=r"Can only parse messages"): - agent._output_parser.parse("text") - - class TestConvertToolsOrRaise: def test_convert_tools_or_raise(self, vertexai_init_mock): pass diff --git a/vertexai/preview/reasoning_engines/templates/langchain.py b/vertexai/preview/reasoning_engines/templates/langchain.py index 8354d71a20..a8e047b1de 100644 --- a/vertexai/preview/reasoning_engines/templates/langchain.py +++ b/vertexai/preview/reasoning_engines/templates/langchain.py @@ -31,13 +31,16 @@ try: from langchain_core import runnables from langchain_core import tools as lc_tools + from langchain_core.language_models import base as lc_language_models BaseTool = lc_tools.BaseTool + BaseLanguageModel = lc_language_models.BaseLanguageModel GetSessionHistoryCallable = runnables.history.GetSessionHistoryCallable RunnableConfig = runnables.RunnableConfig RunnableSerializable = runnables.RunnableSerializable except ImportError: BaseTool = Any + BaseLanguageModel = Any GetSessionHistoryCallable = Any RunnableConfig = Any RunnableSerializable = Any @@ -62,126 +65,84 @@ def _default_runnable_kwargs(has_history: bool) -> Mapping[str, Any]: def _default_output_parser(): - from langchain_core import agents - from langchain_core import output_parsers - from langchain_core import outputs - - class DefaultOutputParser(output_parsers.BaseOutputParser): - - def parse_result( - self, - result: List[outputs.Generation], - ) -> Union[agents.AgentAction, agents.AgentFinish]: - if not isinstance(result[0], outputs.ChatGeneration): - raise ValueError( - "This output parser only works on ChatGeneration output" - ) - msg = result[0].message - content = msg.content - function_call = msg.additional_kwargs.get("function_call", {}) - if function_call: - function_name = function_call["name"] - tool_input = json.loads(function_call.get("arguments", {})) - content_msg = f"responded: {content}\n" if content else 
"\n" - log_msg = ( - f"\nInvoking: `{function_name}` with `{tool_input}`\n" - f"{content_msg}\n" - ) - return agents.AgentActionMessageLog( - tool=function_name, - tool_input=tool_input, - log=log_msg, - message_log=[msg], - ) - return agents.AgentFinish( - return_values={"output": content}, - log=str(content), - ) - - def parse( - self, - text: str, - ) -> Union[agents.AgentAction, agents.AgentFinish]: - raise ValueError("Can only parse messages") - - return DefaultOutputParser() + from langchain.agents.output_parsers.tools import ToolsAgentOutputParser + return ToolsAgentOutputParser() + + +def _default_model_builder( + model_name: str, + *, + project: str, + location: str, + model_kwargs: Optional[Mapping[str, Any]] = None, +) -> "BaseLanguageModel": + import vertexai + from google.cloud.aiplatform import initializer + from langchain_google_vertexai import ChatVertexAI + + model_kwargs = model_kwargs or {} + current_project = initializer.global_config.project + current_location = initializer.global_config.location + vertexai.init(project=project, location=location) + model = ChatVertexAI(model_name=model_name, **model_kwargs) + vertexai.init(project=current_project, location=current_location) + return model + + +def _default_runnable_builder( + model: "BaseLanguageModel", + *, + tools: Optional[Sequence[Union[Callable, "BaseTool"]]] = None, + prompt: Optional["RunnableSerializable"] = None, + output_parser: Optional["RunnableSerializable"] = None, + chat_history: Optional["GetSessionHistoryCallable"] = None, + agent_executor_kwargs: Optional[Mapping[str, Any]] = None, + runnable_kwargs: Optional[Mapping[str, Any]] = None, +) -> "RunnableSerializable": + from langchain_core import tools as lc_tools + from langchain.agents import AgentExecutor + from langchain.tools.base import StructuredTool + # The prompt template and runnable_kwargs needs to be customized depending + # on whether the user intends for the agent to have history. 
The way the + # user would reflect that is by setting chat_history (which defaults to + # None). + has_history: bool = chat_history is not None + prompt = prompt or _default_prompt(has_history) + output_parser = output_parser or _default_output_parser() + agent_executor_kwargs = agent_executor_kwargs or {} + runnable_kwargs = runnable_kwargs or _default_runnable_kwargs(has_history) + if tools: + model = model.bind_tools(tools=tools) + else: + tools = [] + agent_executor = AgentExecutor( + agent=prompt | model | output_parser, + tools=[ + tool if isinstance(tool, lc_tools.BaseTool) + else StructuredTool.from_function(tool) + for tool in tools + ], + **agent_executor_kwargs, + ) + if has_history: + from langchain_core.runnables.history import RunnableWithMessageHistory + return RunnableWithMessageHistory( + runnable=agent_executor, + get_session_history=chat_history, + **runnable_kwargs, + ) + return agent_executor def _default_prompt(has_history: bool) -> "RunnableSerializable": - from langchain_core import agents - from langchain_core import messages from langchain_core import prompts - - def _convert_agent_action_to_messages( - agent_action: agents.AgentAction, observation: str - ) -> List[messages.BaseMessage]: - """Convert an agent action to a message. - - This is used to reconstruct the original message from the agent action. - - Args: - agent_action (AgentAction): The action to convert into messages. - observation (str): The observation to convert into messages. - - Returns: - List[messages.BaseMessage]: A list of messages that corresponds to - the original tool invocation. 
- """ - if isinstance(agent_action, agents.AgentActionMessageLog): - return list(agent_action.message_log) + [ - _create_function_message(agent_action, observation) - ] - else: - return [messages.AIMessage(content=agent_action.log)] - - def _create_function_message( - agent_action: agents.AgentAction, observation: str - ) -> messages.FunctionMessage: - """Convert agent action and observation into a function message. - - Args: - agent_action (AgentAction): tool invocation request from the agent. - observation (str): the result of the tool invocation. - - Returns: - FunctionMessage: A message corresponding to the tool invocation. - """ - if not isinstance(observation, str): - try: - content = json.dumps(observation, ensure_ascii=False) - except Exception: - content = str(observation) - else: - content = observation - return messages.FunctionMessage(name=agent_action.tool, content=content) - - def _format_to_messages( - intermediate_steps: Sequence[Tuple[agents.AgentAction, str]], - ) -> List[messages.BaseMessage]: - """Convert (AgentAction, tool output) tuples into messages. - - Args: - intermediate_steps (Sequence[Tuple[AgentAction, str]]): - Required. Steps the model has taken, along with observations. - - Returns: - List[langchain_core.messages.BaseMessage]: list of messages to send - to the model for the next generation. 
- - """ - scratchpad_messages = [] - for agent_action, observation in intermediate_steps: - scratchpad_messages.extend( - _convert_agent_action_to_messages(agent_action, observation) - ) - return scratchpad_messages - + from langchain.agents.format_scratchpad.tools import format_to_tool_messages if has_history: return { "history": lambda x: x["history"], "input": lambda x: x["input"], "agent_scratchpad": ( - lambda x: _format_to_messages(x["intermediate_steps"]) + lambda x: format_to_tool_messages(x["intermediate_steps"]) ), } | prompts.ChatPromptTemplate.from_messages([ prompts.MessagesPlaceholder(variable_name="history"), @@ -192,7 +153,7 @@ def _format_to_messages( return { "input": lambda x: x["input"], "agent_scratchpad": ( - lambda x: _format_to_messages(x["intermediate_steps"]) + lambda x: format_to_tool_messages(x["intermediate_steps"]) ), } | prompts.ChatPromptTemplate.from_messages([ ("user", "{input}"), @@ -216,22 +177,12 @@ def _validate_callable_parameters_are_annotated(callable: Callable): ) -def _convert_tools_or_raise( - tools: Sequence[Union[Callable, "BaseTool"]] -) -> Sequence["BaseTool"]: - """Converts the tools into Langchain tools (if needed). - - See https://blog.langchain.dev/structured-tools/ for details. 
- """ +def _validate_tools(tools: Sequence[Union[Callable, "BaseTool"]]): + """Validates that the tools are usable for tool calling.""" from langchain_core import tools as lc_tools - from langchain.tools.base import StructuredTool - result = [] for tool in tools: if not isinstance(tool, lc_tools.BaseTool): _validate_callable_parameters_are_annotated(tool) - tool = StructuredTool.from_function(tool) - result.append(tool) - return result class LangchainAgent: @@ -253,19 +204,37 @@ def __init__( model_kwargs: Optional[Mapping[str, Any]] = None, agent_executor_kwargs: Optional[Mapping[str, Any]] = None, runnable_kwargs: Optional[Mapping[str, Any]] = None, + model_builder: Optional[Callable] = None, + runnable_builder: Optional[Callable] = None, ): """Initializes the LangchainAgent. Under-the-hood, assuming .set_up() is called, this will correspond to ``` - from langchain import agents - from langchain_core.runnables.history import RunnableWithMessageHistory - from langchain_google_vertexai import ChatVertexAI + model = model_builder(model_name=model, model_kwargs=model_kwargs) + runnable = runnable_builder( + prompt=prompt, + model=model, + tools=tools, + output_parser=output_parser, + chat_history=chat_history, + agent_executor_kwargs=agent_executor_kwargs, + runnable_kwargs=runnable_kwargs, + ) + ``` + When everything is based on their default values, this corresponds to + ``` + # model_builder + from langchain_google_vertexai import ChatVertexAI llm = ChatVertexAI(model_name=model, **model_kwargs) + + # runnable_builder + from langchain import agents + from langchain_core.runnables.history import RunnableWithMessageHistory agent_executor = agents.AgentExecutor( - agent=prompt | llm.bind(functions=tools) | output_parser, + agent=prompt | llm.bind_tools(tools=tools) | output_parser, tools=tools, **agent_executor_kwargs, ) @@ -337,6 +306,15 @@ def __init__( langchain.runnables.history.RunnableWithMessageHistory if chat_history is specified. 
If chat_history is None, this will be ignored. + model_builder (Callable): + Optional. Callable that returns a new language model. Defaults + to a a callable that returns ChatVertexAI based on `model`, + `model_kwargs` and the parameters in `vertexai.init`. + runnable_builder (Callable): + Optional. Callable that returns a new runnable. This can be used + for customizing the orchestration logic of the Agent based on + the model returned by `model_builder` and the rest of the input + arguments. Raises: TypeError: If there is an invalid tool (e.g. function with an input @@ -347,9 +325,10 @@ def __init__( self._location = initializer.global_config.location self._tools = [] if tools: - # Unlike the other fields, we convert tools at initialization to - # validate the functions/tools before they are deployed. - self._tools = _convert_tools_or_raise(tools) + # We validate tools at initialization for actionable feedback before + # they are deployed. + _validate_tools(tools) + self._tools = tools self._model_name = model self._prompt = prompt self._output_parser = output_parser @@ -357,8 +336,10 @@ def __init__( self._model_kwargs = model_kwargs self._agent_executor_kwargs = agent_executor_kwargs self._runnable_kwargs = runnable_kwargs + self._model = None + self._model_builder = model_builder self._runnable = None - self._chat_history_store = None + self._runnable_builder = runnable_builder def set_up(self): """Sets up the agent for execution of queries at runtime. @@ -370,46 +351,23 @@ def set_up(self): the ReasoningEngine service for deployment, as it initializes clients that can not be serialized. 
""" - from langchain.agents import AgentExecutor - from langchain_core.runnables.history import RunnableWithMessageHistory - from langchain_google_vertexai import ChatVertexAI - import vertexai - from google.cloud.aiplatform import initializer - - has_history = self._chat_history is not None - self._prompt = self._prompt or _default_prompt(has_history) - self._output_parser = self._output_parser or _default_output_parser() - self._model_kwargs = self._model_kwargs or {} - self._agent_executor_kwargs = self._agent_executor_kwargs or {} - self._runnable_kwargs = ( - self._runnable_kwargs or _default_runnable_kwargs(has_history) - ) - - current_project = initializer.global_config.project - current_location = initializer.global_config.location - vertexai.init(project=self._project, location=self._location) - self._llm = ChatVertexAI( + model_builder = self._model_builder or _default_model_builder + self._model = model_builder( model_name=self._model_name, - **self._model_kwargs, + model_kwargs=self._model_kwargs, + project=self._project, + location=self._location, ) - vertexai.init(project=current_project, location=current_location) - - if self._tools: - self._llm = self._llm.bind(functions=self._tools) - self._agent = self._prompt | self._llm | self._output_parser - self._agent_executor = AgentExecutor( - agent=self._agent, + runnable_builder = self._runnable_builder or _default_runnable_builder + self._runnable = runnable_builder( + prompt=self._prompt, + model=self._model, tools=self._tools, - **self._agent_executor_kwargs, + output_parser=self._output_parser, + chat_history=self._chat_history, + agent_executor_kwargs=self._agent_executor_kwargs, + runnable_kwargs=self._runnable_kwargs, ) - runnable = self._agent_executor - if has_history: - runnable = RunnableWithMessageHistory( - runnable=self._agent_executor, - get_session_history=self._chat_history, - **self._runnable_kwargs, - ) - self._runnable = runnable def query( self, From 
73c9a09c3d0ef50a506dc865fd2d4aa200597628 Mon Sep 17 00:00:00 2001 From: "release-please[bot]" <55107282+release-please[bot]@users.noreply.github.com> Date: Fri, 3 May 2024 06:05:32 -0700 Subject: [PATCH 02/30] chore(main): release 1.50.0 Copybara import of the project: -- d42d024ed1cfd5f2c083134cfe779a2e41982900 by release-please[bot] <55107282+release-please[bot]@users.noreply.github.com>: chore(main): release 1.50.0 COPYBARA_INTEGRATE_REVIEW=https://github.com/googleapis/python-aiplatform/pull/3695 from googleapis:release-please--branches--main d42d024ed1cfd5f2c083134cfe779a2e41982900 PiperOrigin-RevId: 630368609 --- .release-please-manifest.json | 2 +- CHANGELOG.md | 24 +++++++++++++++++++ google/cloud/aiplatform/gapic_version.py | 2 +- .../schema/predict/instance/gapic_version.py | 2 +- .../predict/instance_v1/gapic_version.py | 2 +- .../v1/schema/predict/params/gapic_version.py | 2 +- .../schema/predict/params_v1/gapic_version.py | 2 +- .../predict/prediction/gapic_version.py | 2 +- .../predict/prediction_v1/gapic_version.py | 2 +- .../trainingjob/definition/gapic_version.py | 2 +- .../definition_v1/gapic_version.py | 2 +- .../schema/predict/instance/gapic_version.py | 2 +- .../predict/instance_v1beta1/gapic_version.py | 2 +- .../schema/predict/params/gapic_version.py | 2 +- .../predict/params_v1beta1/gapic_version.py | 2 +- .../predict/prediction/gapic_version.py | 2 +- .../prediction_v1beta1/gapic_version.py | 2 +- .../trainingjob/definition/gapic_version.py | 2 +- .../definition_v1beta1/gapic_version.py | 2 +- google/cloud/aiplatform/version.py | 2 +- google/cloud/aiplatform_v1/gapic_version.py | 2 +- .../cloud/aiplatform_v1beta1/gapic_version.py | 2 +- pypi/_vertex_ai_placeholder/version.py | 2 +- ...t_metadata_google.cloud.aiplatform.v1.json | 2 +- ...adata_google.cloud.aiplatform.v1beta1.json | 2 +- 25 files changed, 48 insertions(+), 24 deletions(-) diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 94124d32ea..2a7042a5dd 
100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "1.49.0" + ".": "1.50.0" } diff --git a/CHANGELOG.md b/CHANGELOG.md index 53e1884e47..683a1d2ec5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,29 @@ # Changelog +## [1.50.0](https://github.com/googleapis/python-aiplatform/compare/v1.49.0...v1.50.0) (2024-05-02) + + +### Features + +* Add `Candidate.grounding_metadata` property ([b22a8b8](https://github.com/googleapis/python-aiplatform/commit/b22a8b847e3b299b828e37405e3678093486de28)) +* Add option to not include time_series_metrics in get_experiment_df call. This will improve execution time for Experiments with large number of runs. ([78a95c5](https://github.com/googleapis/python-aiplatform/commit/78a95c52d0e7bd9ec5b656ce67044b2f01677156)) +* Add tune_model and deploy_tuned_model for TextEmbeddingModel. ([42f5d6f](https://github.com/googleapis/python-aiplatform/commit/42f5d6f7cd13d51c4a73113c59e8b3c728cfc08b)) +* Automatically populate parents for full resource name in Vertex RAG SDK ([26657ff](https://github.com/googleapis/python-aiplatform/commit/26657ffd25ecb91882ca764e513c2e952833257f)) +* Deploy a tuned text embedding model -- it doesn't matter, if it's tuned using Node.js, or curl. ([8ca9cdf](https://github.com/googleapis/python-aiplatform/commit/8ca9cdf3576e3ce3b373ace4cd6ab0e9c54aa9f2)) +* Make get_embeddings work both for foundational & tuned models. ([b8b589c](https://github.com/googleapis/python-aiplatform/commit/b8b589ce9fff29d1721450d32b4a84a7f69413c3)) +* Python SDK for Vertex Model Monitoring V2. 
([021d59f](https://github.com/googleapis/python-aiplatform/commit/021d59f1487e4e16c847d4135899d6845c0210aa)) +* Support public endpoint for Ray Client ([57a5f78](https://github.com/googleapis/python-aiplatform/commit/57a5f7815ffb8523e91d900da4ff7cfd0c344fe4)) + + +### Bug Fixes + +* Add deprecation warnings when using Ray v2.4 ([3a36784](https://github.com/googleapis/python-aiplatform/commit/3a367843840513e3257610c8ab38e9f79d3bcea0)) +* Append allowed_plugins in tb-gcp-uploader to default allowed plugins ([aab9c3e](https://github.com/googleapis/python-aiplatform/commit/aab9c3e41b92a1d60090e3d1d594390a5e9f3ff6)) +* LLM - Added missing parameters to the no-op `_TunableTextEmbeddingModelMixin.get_tuned_model` method ([eb05ac4](https://github.com/googleapis/python-aiplatform/commit/eb05ac421f186441a92c6e3b6a010d74caf14782)) +* LVM - Fixed the typo in the VisionModel aspect ratio type annotation ([2d19137](https://github.com/googleapis/python-aiplatform/commit/2d1913773cf9f4a4f8a2c8c8f45680c3ea97f68e)) +* Move torch import ([e6d34df](https://github.com/googleapis/python-aiplatform/commit/e6d34df7da7508c655eb17ee694e1ab2160fc8aa)) +* Ray - Fixed exception when using Ray 2.4 ([2661f52](https://github.com/googleapis/python-aiplatform/commit/2661f52fd08169e5d29b58f2afce9702b30101ae)) + ## [1.49.0](https://github.com/googleapis/python-aiplatform/compare/v1.48.0...v1.49.0) (2024-04-27) diff --git a/google/cloud/aiplatform/gapic_version.py b/google/cloud/aiplatform/gapic_version.py index 41b5da9439..39995f175a 100644 --- a/google/cloud/aiplatform/gapic_version.py +++ b/google/cloud/aiplatform/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.49.0" # {x-release-please-version} +__version__ = "1.50.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py index 41b5da9439..39995f175a 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.49.0" # {x-release-please-version} +__version__ = "1.50.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py index 41b5da9439..39995f175a 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.49.0" # {x-release-please-version} +__version__ = "1.50.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py index 41b5da9439..39995f175a 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.49.0" # {x-release-please-version} +__version__ = "1.50.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py index 41b5da9439..39995f175a 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.49.0" # {x-release-please-version} +__version__ = "1.50.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py index 41b5da9439..39995f175a 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.49.0" # {x-release-please-version} +__version__ = "1.50.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py index 41b5da9439..39995f175a 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.49.0" # {x-release-please-version} +__version__ = "1.50.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py index 41b5da9439..39995f175a 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.49.0" # {x-release-please-version} +__version__ = "1.50.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py index 41b5da9439..39995f175a 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.49.0" # {x-release-please-version} +__version__ = "1.50.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py index 41b5da9439..39995f175a 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.49.0" # {x-release-please-version} +__version__ = "1.50.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py index 41b5da9439..39995f175a 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.49.0" # {x-release-please-version} +__version__ = "1.50.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py index 41b5da9439..39995f175a 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.49.0" # {x-release-please-version} +__version__ = "1.50.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py index 41b5da9439..39995f175a 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.49.0" # {x-release-please-version} +__version__ = "1.50.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py index 41b5da9439..39995f175a 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.49.0" # {x-release-please-version} +__version__ = "1.50.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py index 41b5da9439..39995f175a 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.49.0" # {x-release-please-version} +__version__ = "1.50.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py index 41b5da9439..39995f175a 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.49.0" # {x-release-please-version} +__version__ = "1.50.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py index 41b5da9439..39995f175a 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.49.0" # {x-release-please-version} +__version__ = "1.50.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/version.py b/google/cloud/aiplatform/version.py index 5f3bd3ab54..1af552a3c1 100644 --- a/google/cloud/aiplatform/version.py +++ b/google/cloud/aiplatform/version.py @@ -15,4 +15,4 @@ # limitations under the License. # -__version__ = "1.49.0" +__version__ = "1.50.0" diff --git a/google/cloud/aiplatform_v1/gapic_version.py b/google/cloud/aiplatform_v1/gapic_version.py index 41b5da9439..39995f175a 100644 --- a/google/cloud/aiplatform_v1/gapic_version.py +++ b/google/cloud/aiplatform_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.49.0" # {x-release-please-version} +__version__ = "1.50.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform_v1beta1/gapic_version.py b/google/cloud/aiplatform_v1beta1/gapic_version.py index 41b5da9439..39995f175a 100644 --- a/google/cloud/aiplatform_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.49.0" # {x-release-please-version} +__version__ = "1.50.0" # {x-release-please-version} diff --git a/pypi/_vertex_ai_placeholder/version.py b/pypi/_vertex_ai_placeholder/version.py index 78cddd4821..de799ba97d 100644 --- a/pypi/_vertex_ai_placeholder/version.py +++ b/pypi/_vertex_ai_placeholder/version.py @@ -15,4 +15,4 @@ # limitations under the License. # -__version__ = "1.49.0" +__version__ = "1.50.0" diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json index 5234f5f287..39acefe786 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.49.0" + "version": "1.50.0" }, "snippets": [ { diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json index 27545fcbc8..79deba42e4 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.49.0" + "version": "1.50.0" }, "snippets": [ { From e8fe28d3c5df5c447a2c1bc72a04f6320db5e2ad Mon Sep 17 00:00:00 2001 From: "gcf-owl-bot[bot]" <78513119+gcf-owl-bot[bot]@users.noreply.github.com> Date: Fri, 3 May 2024 08:46:24 -0700 Subject: [PATCH 03/30] Copybara import of the project: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit -- 7c236b12a435601fa399a67a54b970d6d49e67e9 by Owl Bot : feat: A new field `search_entry_point` is added to message `.google.cloud.aiplatform.v1beta1.GroundingMetadata` feat: A new message `SearchEntryPoint` is added 
feat: A new method `UpdateDatasetVersion` is added to service `DatasetService` feat: A new message `UpdateDatasetVersionRequest` is added feat: A new field `private_service_connect_config` is added to message `.google.cloud.aiplatform.v1beta1.Endpoint` feat: A new field `app_id` is added to message `.google.cloud.aiplatform.v1beta1.RuntimeConfig` feat: A new value `INVALID_TOKEN_VALUE` is added to enum `RecordErrorType` feat: A new field `valid_sparse_record_count` is added to message `.google.cloud.aiplatform.v1beta1.NearestNeighborSearchOperationMetadata` feat: A new field `invalid_sparse_record_count` is added to message `.google.cloud.aiplatform.v1beta1.NearestNeighborSearchOperationMetadata` feat: A new message `DirectNotebookSource` is added feat: A new message `CustomEnvironmentSpec` is added feat: A new field `direct_notebook_source` is added to message `.google.cloud.aiplatform.v1beta1.NotebookExecutionJob` feat: A new field `custom_environment_spec` is added to message `.google.cloud.aiplatform.v1beta1.NotebookExecutionJob` feat: A new message `CreateNotebookExecutionJobRequest` is added feat: A new field `deploy_task_name` is added to message `.google.cloud.aiplatform.v1beta1.PublisherModel` feat: A new field `fine_tune` is added to message `.google.cloud.aiplatform.v1beta1.PublisherModel` feat: A new field `create_notebook_execution_job_request` is added to message `.google.cloud.aiplatform.v1beta1.Schedule` feat: A new message `RagResource` is added feat: A new field `rag_resources` is added to message `.google.cloud.aiplatform.v1beta1.VertexRagStore` feat: A new field `vector_distance_threshold` is added to message `.google.cloud.aiplatform.v1beta1.VertexRagStore` feat: A new field `failed_rag_files_count` is added to message `.google.cloud.aiplatform.v1beta1.ImportRagFilesResponse` feat: A new field `skipped_rag_files_count` is added to message `.google.cloud.aiplatform.v1beta1.ImportRagFilesResponse` feat: A new field `import_rag_files_config` is 
added to message `.google.cloud.aiplatform.v1beta1.ImportRagFilesOperationMetadata` feat: A new message `RagResource` is added feat: A new field `rag_resources` is added to message `.google.cloud.aiplatform.v1beta1.RetrieveContextsRequest` feat: A new field `vector_distance_threshold` is added to message `.google.cloud.aiplatform.v1beta1.RetrieveContextsRequest` fix!: An existing method `ChatCompletions` is removed from service `PredictionService` fix!: An existing message `ChatCompletionsRequest` is removed docs: A comment for field `rouge_type` in message `.google.cloud.aiplatform.v1beta1.RougeSpec` is changed docs: A comment for field `file_input_gcs_bucket` in message `.google.cloud.aiplatform.v1beta1.RuntimeConfig` is changed docs: A comment for field `file_output_gcs_bucket` in message `.google.cloud.aiplatform.v1beta1.RuntimeConfig` is changed docs: A comment for field `serving_config_name` in message `.google.cloud.aiplatform.v1beta1.RuntimeConfig` is changed docs: A comment for field `big_query` in message `.google.cloud.aiplatform.v1beta1.FeatureGroup` is changed docs: A comment for field `parent` in message `.google.cloud.aiplatform.v1beta1.CreateFeatureGroupRequest` is changed docs: A comment for field `feature_vector` in message `.google.cloud.aiplatform.v1beta1.IndexDatapoint` is changed docs: A comment for field `vectors_count` in message `.google.cloud.aiplatform.v1beta1.IndexStats` is changed docs: A comment for enum value `EMBEDDING_SIZE_MISMATCH` in enum `RecordErrorType` is changed docs: A comment for field `distance` in message `.google.cloud.aiplatform.v1beta1.FindNeighborsResponse` is changed docs: A comment for field `gcs_notebook_source` in message `.google.cloud.aiplatform.v1beta1.NotebookExecutionJob` is changed docs: A comment for field `gcs_output_uri` in message `.google.cloud.aiplatform.v1beta1.NotebookExecutionJob` is changed docs: A comment for field `name` in message `.google.cloud.aiplatform.v1beta1.NotebookRuntimeTemplate` is 
changed docs: A comment for field `rag_corpora` in message `.google.cloud.aiplatform.v1beta1.VertexRagStore` is changed docs: A comment for field `gcs_source` in message `.google.cloud.aiplatform.v1beta1.RagFile` is changed docs: A comment for field `rag_corpora` in message `.google.cloud.aiplatform.v1beta1.RetrieveContextsRequest` is changed PiperOrigin-RevId: 629842300 Source-Link: https://github.com/googleapis/googleapis/commit/f86c175bbee8aeb00492aae9d6aacf3e7bb99789 Source-Link: https://github.com/googleapis/googleapis-gen/commit/281af7a485bfdd2e9eb280cd88f8d660a6135781 Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiMjgxYWY3YTQ4NWJmZGQyZTllYjI4MGNkODhmOGQ2NjBhNjEzNTc4MSJ9 feat: A new value `TPU_V5_LITEPOD` is added to enum `AcceleratorType` feat: A new field `search_entry_point` is added to message `.google.cloud.aiplatform.v1.GroundingMetadata` feat: A new message `SearchEntryPoint` is added feat: A new field `private_service_connect_config` is added to message `.google.cloud.aiplatform.v1.Endpoint` feat: A new value `INVALID_TOKEN_VALUE` is added to enum `RecordErrorType` feat: A new field `deploy_task_name` is added to message `.google.cloud.aiplatform.v1.PublisherModel` docs: A comment for field `parent` in message `.google.cloud.aiplatform.v1.CreateFeatureGroupRequest` is changed docs: A comment for field `name` in message `.google.cloud.aiplatform.v1.NotebookRuntimeTemplate` is changed docs: A comment for field `base_model` in message `.google.cloud.aiplatform.v1.TuningJob` is changed docs: A comment for field `tuned_model_display_name` in message `.google.cloud.aiplatform.v1.TuningJob` is changed docs: A comment for field `epoch_count` in message `.google.cloud.aiplatform.v1.SupervisedHyperParameters` is changed docs: A comment for field `learning_rate_multiplier` in message `.google.cloud.aiplatform.v1.SupervisedHyperParameters` is changed docs: A comment for field `training_dataset_uri` in message 
`.google.cloud.aiplatform.v1.SupervisedTuningSpec` is changed docs: A comment for field `validation_dataset_uri` in message `.google.cloud.aiplatform.v1.SupervisedTuningSpec` is changed PiperOrigin-RevId: 629522152 Source-Link: https://github.com/googleapis/googleapis/commit/f9767ca40b6e547ffe59e998659326e69d71f269 Source-Link: https://github.com/googleapis/googleapis-gen/commit/95e3cf4a65180e7a8e26853a03c335c77a98d0c5 Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiOTVlM2NmNGE2NTE4MGU3YThlMjY4NTNhMDNjMzM1Yzc3YTk4ZDBjNSJ9 chore: Update gapic-generator-python to v1.17.1 PiperOrigin-RevId: 629071173 Source-Link: https://github.com/googleapis/googleapis/commit/4afa392105cc62e965631d15b772ff68454ecf1c Source-Link: https://github.com/googleapis/googleapis-gen/commit/16dbbb4d0457db5e61ac9f99b0d52a46154455ac Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiMTZkYmJiNGQwNDU3ZGI1ZTYxYWM5Zjk5YjBkNTJhNDYxNTQ0NTVhYyJ9 feat: add NotebookExecutionJob resource and APIs to public v1beta1 client library PiperOrigin-RevId: 628125855 Source-Link: https://github.com/googleapis/googleapis/commit/f41b4bcf4bc6e5de0bf6de27d50e81ca7562690b Source-Link: https://github.com/googleapis/googleapis-gen/commit/2d65169af736448eb5a66646134c1440152713a7 Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiMmQ2NTE2OWFmNzM2NDQ4ZWI1YTY2NjQ2MTM0YzE0NDAxNTI3MTNhNyJ9 -- 1505fe314d58f2d00d7be72a3c36d5508056e0b7 by Owl Bot : 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md COPYBARA_INTEGRATE_REVIEW=https://github.com/googleapis/python-aiplatform/pull/3673 from googleapis:owl-bot-copy 9d828e6e86381974952a012aee4627cb77fc46ff PiperOrigin-RevId: 630401530 --- google/cloud/aiplatform_v1/__init__.py | 2 + .../services/dataset_service/async_client.py | 315 +- .../services/dataset_service/client.py | 200 +- .../dataset_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 121 +- .../async_client.py | 138 +- 
.../client.py | 81 +- .../transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 56 +- .../services/endpoint_service/async_client.py | 152 +- .../services/endpoint_service/client.py | 107 +- .../endpoint_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 71 +- .../async_client.py | 304 +- .../client.py | 161 +- .../transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 96 +- .../async_client.py | 51 +- .../feature_online_store_service/client.py | 45 +- .../transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 41 +- .../feature_registry_service/async_client.py | 189 +- .../feature_registry_service/client.py | 131 +- .../transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 81 +- .../async_client.py | 76 +- .../client.py | 59 +- .../transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 46 +- .../featurestore_service/async_client.py | 374 +- .../services/featurestore_service/client.py | 239 +- .../featurestore_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 136 +- .../gen_ai_tuning_service/async_client.py | 85 +- .../services/gen_ai_tuning_service/client.py | 69 +- .../gen_ai_tuning_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 51 +- .../index_endpoint_service/async_client.py | 153 +- .../services/index_endpoint_service/client.py | 109 +- .../index_endpoint_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 71 +- .../services/index_service/async_client.py | 124 +- .../services/index_service/client.py | 86 +- .../services/index_service/transports/grpc.py | 25 +- .../index_service/transports/grpc_asyncio.py | 66 +- .../services/job_service/async_client.py | 626 +- .../services/job_service/client.py | 374 +- .../services/job_service/transports/grpc.py | 25 +- .../job_service/transports/grpc_asyncio.py | 206 +- .../llm_utility_service/async_client.py | 51 +- .../services/llm_utility_service/client.py | 49 +- .../llm_utility_service/transports/grpc.py | 25 +- 
.../transports/grpc_asyncio.py | 41 +- .../services/match_service/async_client.py | 39 +- .../services/match_service/client.py | 36 +- .../services/match_service/transports/grpc.py | 25 +- .../match_service/transports/grpc_asyncio.py | 41 +- .../services/metadata_service/async_client.py | 565 +- .../services/metadata_service/client.py | 346 +- .../metadata_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 191 +- .../migration_service/async_client.py | 49 +- .../services/migration_service/client.py | 65 +- .../migration_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 41 +- .../model_garden_service/async_client.py | 34 +- .../services/model_garden_service/client.py | 39 +- .../model_garden_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 36 +- .../services/model_service/async_client.py | 323 +- .../services/model_service/client.py | 204 +- .../services/model_service/transports/grpc.py | 25 +- .../model_service/transports/grpc_asyncio.py | 121 +- .../services/notebook_service/async_client.py | 191 +- .../services/notebook_service/client.py | 126 +- .../notebook_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 81 +- .../async_client.py | 141 +- .../persistent_resource_service/client.py | 89 +- .../transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 61 +- .../services/pipeline_service/async_client.py | 219 +- .../services/pipeline_service/client.py | 146 +- .../pipeline_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 91 +- .../prediction_service/async_client.py | 186 +- .../services/prediction_service/client.py | 107 +- .../prediction_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 96 +- .../services/schedule_service/async_client.py | 134 +- .../services/schedule_service/client.py | 96 +- .../schedule_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 66 +- .../specialist_pool_service/async_client.py | 102 +- 
.../specialist_pool_service/client.py | 79 +- .../transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 56 +- .../tensorboard_service/async_client.py | 561 +- .../services/tensorboard_service/client.py | 329 +- .../tensorboard_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 181 +- .../services/vizier_service/async_client.py | 248 +- .../services/vizier_service/client.py | 154 +- .../vizier_service/transports/grpc.py | 25 +- .../vizier_service/transports/grpc_asyncio.py | 106 +- google/cloud/aiplatform_v1/types/__init__.py | 2 + .../aiplatform_v1/types/accelerator_type.py | 3 + google/cloud/aiplatform_v1/types/content.py | 36 + google/cloud/aiplatform_v1/types/endpoint.py | 14 + .../types/feature_registry_service.py | 2 +- .../aiplatform_v1/types/index_service.py | 3 + .../aiplatform_v1/types/notebook_runtime.py | 2 +- .../aiplatform_v1/types/publisher_model.py | 10 + .../cloud/aiplatform_v1/types/tuning_job.py | 17 +- google/cloud/aiplatform_v1beta1/__init__.py | 20 +- .../aiplatform_v1beta1/gapic_metadata.json | 75 +- .../services/dataset_service/async_client.py | 440 +- .../services/dataset_service/client.py | 322 +- .../dataset_service/transports/base.py | 18 + .../dataset_service/transports/grpc.py | 55 +- .../transports/grpc_asyncio.py | 156 +- .../dataset_service/transports/rest.py | 178 +- .../async_client.py | 138 +- .../client.py | 81 +- .../transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 56 +- .../transports/rest.py | 40 - .../services/endpoint_service/async_client.py | 152 +- .../services/endpoint_service/client.py | 107 +- .../endpoint_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 71 +- .../endpoint_service/transports/rest.py | 40 - .../evaluation_service/async_client.py | 30 +- .../services/evaluation_service/client.py | 35 +- .../evaluation_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 36 +- .../evaluation_service/transports/rest.py | 20 - .../async_client.py | 51 +- 
.../extension_execution_service/client.py | 49 +- .../transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 41 +- .../transports/rest.py | 20 - .../async_client.py | 102 +- .../extension_registry_service/client.py | 79 +- .../transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 56 +- .../transports/rest.py | 40 - .../async_client.py | 304 +- .../client.py | 161 +- .../transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 96 +- .../transports/rest.py | 40 - .../async_client.py | 59 +- .../feature_online_store_service/client.py | 45 +- .../transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 46 +- .../transports/rest.py | 20 - .../feature_registry_service/async_client.py | 189 +- .../feature_registry_service/client.py | 131 +- .../transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 81 +- .../transports/rest.py | 40 - .../async_client.py | 76 +- .../client.py | 59 +- .../transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 46 +- .../transports/rest.py | 20 - .../featurestore_service/async_client.py | 374 +- .../services/featurestore_service/client.py | 239 +- .../featurestore_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 136 +- .../featurestore_service/transports/rest.py | 40 - .../index_endpoint_service/async_client.py | 153 +- .../services/index_endpoint_service/client.py | 109 +- .../index_endpoint_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 71 +- .../index_endpoint_service/transports/rest.py | 40 - .../services/index_service/async_client.py | 124 +- .../services/index_service/client.py | 86 +- .../services/index_service/transports/grpc.py | 25 +- .../index_service/transports/grpc_asyncio.py | 66 +- .../services/index_service/transports/rest.py | 40 - .../services/job_service/async_client.py | 626 +- .../services/job_service/client.py | 374 +- .../services/job_service/transports/grpc.py | 25 +- .../job_service/transports/grpc_asyncio.py | 206 +- 
.../services/job_service/transports/rest.py | 40 - .../llm_utility_service/async_client.py | 34 +- .../services/llm_utility_service/client.py | 39 +- .../llm_utility_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 36 +- .../llm_utility_service/transports/rest.py | 20 - .../services/match_service/async_client.py | 39 +- .../services/match_service/client.py | 36 +- .../services/match_service/transports/grpc.py | 25 +- .../match_service/transports/grpc_asyncio.py | 41 +- .../services/match_service/transports/rest.py | 20 - .../services/metadata_service/async_client.py | 565 +- .../services/metadata_service/client.py | 346 +- .../metadata_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 191 +- .../metadata_service/transports/rest.py | 40 - .../migration_service/async_client.py | 49 +- .../services/migration_service/client.py | 65 +- .../migration_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 41 +- .../migration_service/transports/rest.py | 40 - .../model_garden_service/async_client.py | 51 +- .../services/model_garden_service/client.py | 49 +- .../model_garden_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 41 +- .../model_garden_service/transports/rest.py | 20 - .../model_monitoring_service/async_client.py | 220 +- .../model_monitoring_service/client.py | 139 +- .../transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 86 +- .../transports/rest.py | 40 - .../services/model_service/async_client.py | 323 +- .../services/model_service/client.py | 204 +- .../services/model_service/transports/grpc.py | 25 +- .../model_service/transports/grpc_asyncio.py | 121 +- .../services/model_service/transports/rest.py | 40 - .../services/notebook_service/async_client.py | 564 +- .../services/notebook_service/client.py | 532 +- .../services/notebook_service/pagers.py | 133 + .../notebook_service/transports/base.py | 49 + .../notebook_service/transports/grpc.py | 114 +- .../transports/grpc_asyncio.py 
| 186 +- .../notebook_service/transports/rest.py | 438 +- .../async_client.py | 141 +- .../persistent_resource_service/client.py | 89 +- .../transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 61 +- .../transports/rest.py | 40 - .../services/pipeline_service/async_client.py | 219 +- .../services/pipeline_service/client.py | 146 +- .../pipeline_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 91 +- .../pipeline_service/transports/rest.py | 40 - .../prediction_service/async_client.py | 350 +- .../services/prediction_service/client.py | 271 +- .../prediction_service/transports/base.py | 14 - .../prediction_service/transports/grpc.py | 52 +- .../transports/grpc_asyncio.py | 123 +- .../prediction_service/transports/rest.py | 197 - .../async_client.py | 36 +- .../client.py | 35 +- .../transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 36 +- .../transports/rest.py | 20 - .../reasoning_engine_service/async_client.py | 91 +- .../reasoning_engine_service/client.py | 69 +- .../transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 51 +- .../transports/rest.py | 40 - .../services/schedule_service/async_client.py | 147 +- .../services/schedule_service/client.py | 141 +- .../schedule_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 66 +- .../schedule_service/transports/rest.py | 40 - .../specialist_pool_service/async_client.py | 102 +- .../specialist_pool_service/client.py | 79 +- .../transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 56 +- .../transports/rest.py | 40 - .../tensorboard_service/async_client.py | 561 +- .../services/tensorboard_service/client.py | 329 +- .../tensorboard_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 181 +- .../tensorboard_service/transports/rest.py | 40 - .../vertex_rag_data_service/async_client.py | 170 +- .../vertex_rag_data_service/client.py | 119 +- .../transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 76 +- .../transports/rest.py | 40 - 
.../vertex_rag_service/async_client.py | 38 +- .../services/vertex_rag_service/client.py | 63 +- .../vertex_rag_service/transports/grpc.py | 25 +- .../transports/grpc_asyncio.py | 36 +- .../vertex_rag_service/transports/rest.py | 20 - .../services/vizier_service/async_client.py | 248 +- .../services/vizier_service/client.py | 154 +- .../vizier_service/transports/grpc.py | 25 +- .../vizier_service/transports/grpc_asyncio.py | 106 +- .../vizier_service/transports/rest.py | 40 - .../aiplatform_v1beta1/types/__init__.py | 22 +- .../cloud/aiplatform_v1beta1/types/content.py | 36 + .../types/dataset_service.py | 30 + .../aiplatform_v1beta1/types/endpoint.py | 15 + .../types/evaluation_service.py | 2 +- .../aiplatform_v1beta1/types/extension.py | 39 +- .../aiplatform_v1beta1/types/feature_group.py | 5 +- .../types/feature_registry_service.py | 2 +- .../aiplatform_v1beta1/types/index_service.py | 3 + .../aiplatform_v1beta1/types/match_service.py | 2 +- .../types/notebook_execution_job.py | 223 + .../types/notebook_runtime.py | 2 +- .../types/notebook_service.py | 202 + .../types/prediction_service.py | 25 - .../types/publisher_model.py | 10 + .../aiplatform_v1beta1/types/schedule.py | 11 + google/cloud/aiplatform_v1beta1/types/tool.py | 48 +- .../types/vertex_rag_data.py | 2 +- .../types/vertex_rag_data_service.py | 22 + .../types/vertex_rag_service.py | 53 +- ...et_service_update_dataset_version_async.py | 55 + ...set_service_update_dataset_version_sync.py | 55 + ...ice_delete_notebook_execution_job_async.py | 56 + ...vice_delete_notebook_execution_job_sync.py | 56 + ...ervice_get_notebook_execution_job_async.py | 52 + ...service_get_notebook_execution_job_sync.py | 52 + ...vice_list_notebook_execution_jobs_async.py | 53 + ...rvice_list_notebook_execution_jobs_sync.py | 53 + ...tex_rag_service_retrieve_contexts_async.py | 4 - ...rtex_rag_service_retrieve_contexts_sync.py | 4 - ...t_metadata_google.cloud.aiplatform.v1.json | 2 +- 
...adata_google.cloud.aiplatform.v1beta1.json | 881 +- .../aiplatform_v1/test_dataset_service.py | 2402 ++++- .../test_deployment_resource_pool_service.py | 689 ++ .../aiplatform_v1/test_endpoint_service.py | 1069 +++ ...test_feature_online_store_admin_service.py | 1801 +++- .../test_feature_online_store_service.py | 264 + .../test_feature_registry_service.py | 1338 +++ ...est_featurestore_online_serving_service.py | 395 + .../test_featurestore_service.py | 2875 +++++- .../test_gen_ai_tuning_service.py | 504 + .../test_index_endpoint_service.py | 1100 +++ .../gapic/aiplatform_v1/test_index_service.py | 903 ++ .../gapic/aiplatform_v1/test_job_service.py | 4763 +++++++++- .../aiplatform_v1/test_llm_utility_service.py | 246 + .../gapic/aiplatform_v1/test_match_service.py | 256 + .../aiplatform_v1/test_metadata_service.py | 4302 ++++++++- .../aiplatform_v1/test_migration_service.py | 304 +- .../test_model_garden_service.py | 131 + .../gapic/aiplatform_v1/test_model_service.py | 2404 ++++- .../aiplatform_v1/test_notebook_service.py | 1400 +++ .../test_persistent_resource_service.py | 846 ++ .../aiplatform_v1/test_pipeline_service.py | 1648 +++- .../aiplatform_v1/test_prediction_service.py | 1484 +++ .../aiplatform_v1/test_schedule_service.py | 873 ++ .../test_specialist_pool_service.py | 699 ++ .../aiplatform_v1/test_tensorboard_service.py | 4046 +++++++- .../aiplatform_v1/test_vizier_service.py | 1905 +++- .../test_dataset_service.py | 8161 ++++++++++++----- .../test_deployment_resource_pool_service.py | 689 ++ .../test_endpoint_service.py | 1069 +++ .../test_evaluation_service.py | 131 + .../test_extension_execution_service.py | 250 + .../test_extension_registry_service.py | 657 +- ...test_feature_online_store_admin_service.py | 1801 +++- .../test_feature_online_store_service.py | 350 + .../test_feature_registry_service.py | 1338 +++ ...est_featurestore_online_serving_service.py | 395 + .../test_featurestore_service.py | 2875 +++++- .../test_index_endpoint_service.py 
| 1100 +++ .../aiplatform_v1beta1/test_index_service.py | 903 ++ .../aiplatform_v1beta1/test_job_service.py | 4763 +++++++++- .../test_llm_utility_service.py | 123 + .../aiplatform_v1beta1/test_match_service.py | 256 + .../test_metadata_service.py | 4302 ++++++++- .../test_migration_service.py | 304 +- .../test_model_garden_service.py | 264 + .../test_model_monitoring_service.py | 1505 ++- .../aiplatform_v1beta1/test_model_service.py | 2404 ++++- .../test_notebook_service.py | 5654 ++++++++++-- .../test_persistent_resource_service.py | 846 ++ .../test_pipeline_service.py | 1648 +++- .../test_prediction_service.py | 2557 ++++-- ...test_reasoning_engine_execution_service.py | 133 + .../test_reasoning_engine_service.py | 554 ++ .../test_schedule_service.py | 984 ++ .../test_specialist_pool_service.py | 699 ++ .../test_tensorboard_service.py | 4046 +++++++- .../test_vertex_rag_data_service.py | 1171 +++ .../test_vertex_rag_service.py | 177 +- .../aiplatform_v1beta1/test_vizier_service.py | 1905 +++- 376 files changed, 107380 insertions(+), 16757 deletions(-) create mode 100644 google/cloud/aiplatform_v1beta1/types/notebook_execution_job.py create mode 100644 samples/generated_samples/aiplatform_v1beta1_generated_dataset_service_update_dataset_version_async.py create mode 100644 samples/generated_samples/aiplatform_v1beta1_generated_dataset_service_update_dataset_version_sync.py create mode 100644 samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_delete_notebook_execution_job_async.py create mode 100644 samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_delete_notebook_execution_job_sync.py create mode 100644 samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_get_notebook_execution_job_async.py create mode 100644 samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_get_notebook_execution_job_sync.py create mode 100644 
samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_list_notebook_execution_jobs_async.py create mode 100644 samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_list_notebook_execution_jobs_sync.py diff --git a/google/cloud/aiplatform_v1/__init__.py b/google/cloud/aiplatform_v1/__init__.py index 86f821a12a..e772ab26c6 100644 --- a/google/cloud/aiplatform_v1/__init__.py +++ b/google/cloud/aiplatform_v1/__init__.py @@ -101,6 +101,7 @@ from .types.content import Part from .types.content import SafetyRating from .types.content import SafetySetting +from .types.content import SearchEntryPoint from .types.content import Segment from .types.content import VideoMetadata from .types.content import HarmCategory @@ -1400,6 +1401,7 @@ "Schema", "SearchDataItemsRequest", "SearchDataItemsResponse", + "SearchEntryPoint", "SearchFeaturesRequest", "SearchFeaturesResponse", "SearchMigratableResourcesRequest", diff --git a/google/cloud/aiplatform_v1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1/services/dataset_service/async_client.py index 528b437cae..1e31c9add8 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -229,7 +230,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, DatasetServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[str, DatasetServiceTransport, Callable[..., DatasetServiceTransport]] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -241,9 +244,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. 
- transport (Union[str, ~.DatasetServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,DatasetServiceTransport,Callable[..., DatasetServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the DatasetServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -371,8 +376,8 @@ async def sample_create_dataset(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset]) if request is not None and has_flattened_params: raise ValueError( @@ -380,7 +385,10 @@ async def sample_create_dataset(): "the individual field arguments should be set." ) - request = dataset_service.CreateDatasetRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.CreateDatasetRequest): + request = dataset_service.CreateDatasetRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -391,11 +399,9 @@ async def sample_create_dataset(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_dataset, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_dataset + ] # Certain fields should be provided within the metadata header; # add these here. @@ -486,8 +492,8 @@ async def sample_get_dataset(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -495,7 +501,10 @@ async def sample_get_dataset(): "the individual field arguments should be set." ) - request = dataset_service.GetDatasetRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.GetDatasetRequest): + request = dataset_service.GetDatasetRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -504,11 +513,9 @@ async def sample_get_dataset(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_dataset, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_dataset + ] # Certain fields should be provided within the metadata header; # add these here. @@ -610,8 +617,8 @@ async def sample_update_dataset(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([dataset, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -619,7 +626,10 @@ async def sample_update_dataset(): "the individual field arguments should be set." ) - request = dataset_service.UpdateDatasetRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.UpdateDatasetRequest): + request = dataset_service.UpdateDatasetRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -630,11 +640,9 @@ async def sample_update_dataset(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_dataset, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_dataset + ] # Certain fields should be provided within the metadata header; # add these here. @@ -723,8 +731,8 @@ async def sample_list_datasets(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -732,7 +740,10 @@ async def sample_list_datasets(): "the individual field arguments should be set." ) - request = dataset_service.ListDatasetsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, dataset_service.ListDatasetsRequest): + request = dataset_service.ListDatasetsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -741,11 +752,9 @@ async def sample_list_datasets(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_datasets, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_datasets + ] # Certain fields should be provided within the metadata header; # add these here. @@ -852,8 +861,8 @@ async def sample_delete_dataset(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -861,7 +870,10 @@ async def sample_delete_dataset(): "the individual field arguments should be set." ) - request = dataset_service.DeleteDatasetRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.DeleteDatasetRequest): + request = dataset_service.DeleteDatasetRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -870,11 +882,9 @@ async def sample_delete_dataset(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_dataset, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_dataset + ] # Certain fields should be provided within the metadata header; # add these here. @@ -985,8 +995,8 @@ async def sample_import_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, import_configs]) if request is not None and has_flattened_params: raise ValueError( @@ -994,7 +1004,10 @@ async def sample_import_data(): "the individual field arguments should be set." ) - request = dataset_service.ImportDataRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.ImportDataRequest): + request = dataset_service.ImportDataRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1005,11 +1018,9 @@ async def sample_import_data(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.import_data, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.import_data + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1118,8 +1129,8 @@ async def sample_export_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, export_config]) if request is not None and has_flattened_params: raise ValueError( @@ -1127,7 +1138,10 @@ async def sample_export_data(): "the individual field arguments should be set." ) - request = dataset_service.ExportDataRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.ExportDataRequest): + request = dataset_service.ExportDataRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1138,11 +1152,9 @@ async def sample_export_data(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.export_data, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.export_data + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1257,8 +1269,8 @@ async def sample_create_dataset_version(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset_version]) if request is not None and has_flattened_params: raise ValueError( @@ -1266,7 +1278,10 @@ async def sample_create_dataset_version(): "the individual field arguments should be set." 
) - request = dataset_service.CreateDatasetVersionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.CreateDatasetVersionRequest): + request = dataset_service.CreateDatasetVersionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1277,11 +1292,9 @@ async def sample_create_dataset_version(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_dataset_version, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_dataset_version + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1389,8 +1402,8 @@ async def sample_delete_dataset_version(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1398,7 +1411,10 @@ async def sample_delete_dataset_version(): "the individual field arguments should be set." ) - request = dataset_service.DeleteDatasetVersionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.DeleteDatasetVersionRequest): + request = dataset_service.DeleteDatasetVersionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -1407,11 +1423,9 @@ async def sample_delete_dataset_version(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_dataset_version, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_dataset_version + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1501,8 +1515,8 @@ async def sample_get_dataset_version(): Describes the dataset version. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1510,7 +1524,10 @@ async def sample_get_dataset_version(): "the individual field arguments should be set." ) - request = dataset_service.GetDatasetVersionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.GetDatasetVersionRequest): + request = dataset_service.GetDatasetVersionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1519,11 +1536,9 @@ async def sample_get_dataset_version(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_dataset_version, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_dataset_version + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1613,8 +1628,8 @@ async def sample_list_dataset_versions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1622,7 +1637,10 @@ async def sample_list_dataset_versions(): "the individual field arguments should be set." ) - request = dataset_service.ListDatasetVersionsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.ListDatasetVersionsRequest): + request = dataset_service.ListDatasetVersionsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1631,11 +1649,9 @@ async def sample_list_dataset_versions(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_dataset_versions, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_dataset_versions + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1737,8 +1753,8 @@ async def sample_restore_dataset_version(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1746,7 +1762,10 @@ async def sample_restore_dataset_version(): "the individual field arguments should be set." ) - request = dataset_service.RestoreDatasetVersionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.RestoreDatasetVersionRequest): + request = dataset_service.RestoreDatasetVersionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1755,11 +1774,9 @@ async def sample_restore_dataset_version(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.restore_dataset_version, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.restore_dataset_version + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1855,8 +1872,8 @@ async def sample_list_data_items(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1864,7 +1881,10 @@ async def sample_list_data_items(): "the individual field arguments should be set." 
) - request = dataset_service.ListDataItemsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.ListDataItemsRequest): + request = dataset_service.ListDataItemsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1873,11 +1893,9 @@ async def sample_list_data_items(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_data_items, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_data_items + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1966,15 +1984,16 @@ async def sample_search_data_items(): """ # Create or coerce a protobuf request object. - request = dataset_service.SearchDataItemsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.SearchDataItemsRequest): + request = dataset_service.SearchDataItemsRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.search_data_items, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.search_data_items + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2071,8 +2090,8 @@ async def sample_list_saved_queries(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2080,7 +2099,10 @@ async def sample_list_saved_queries(): "the individual field arguments should be set." ) - request = dataset_service.ListSavedQueriesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.ListSavedQueriesRequest): + request = dataset_service.ListSavedQueriesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2089,11 +2111,9 @@ async def sample_list_saved_queries(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_saved_queries, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_saved_queries + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2200,8 +2220,8 @@ async def sample_delete_saved_query(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2209,7 +2229,10 @@ async def sample_delete_saved_query(): "the individual field arguments should be set." 
) - request = dataset_service.DeleteSavedQueryRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.DeleteSavedQueryRequest): + request = dataset_service.DeleteSavedQueryRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2218,11 +2241,9 @@ async def sample_delete_saved_query(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_saved_query, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_saved_query + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2314,8 +2335,8 @@ async def sample_get_annotation_spec(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2323,7 +2344,10 @@ async def sample_get_annotation_spec(): "the individual field arguments should be set." ) - request = dataset_service.GetAnnotationSpecRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.GetAnnotationSpecRequest): + request = dataset_service.GetAnnotationSpecRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -2332,11 +2356,9 @@ async def sample_get_annotation_spec(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_annotation_spec, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_annotation_spec + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2424,8 +2446,8 @@ async def sample_list_annotations(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2433,7 +2455,10 @@ async def sample_list_annotations(): "the individual field arguments should be set." ) - request = dataset_service.ListAnnotationsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.ListAnnotationsRequest): + request = dataset_service.ListAnnotationsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2442,11 +2467,9 @@ async def sample_list_annotations(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_annotations, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_annotations + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1/services/dataset_service/client.py b/google/cloud/aiplatform_v1/services/dataset_service/client.py index 5876faca8d..bc32600bc8 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -672,7 +673,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, DatasetServiceTransport]] = None, + transport: Optional[ + Union[str, DatasetServiceTransport, Callable[..., DatasetServiceTransport]] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -684,9 +687,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, DatasetServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,DatasetServiceTransport,Callable[..., DatasetServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the DatasetServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -798,8 +803,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[DatasetServiceTransport], Callable[..., DatasetServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., DatasetServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -892,8 +904,8 @@ def sample_create_dataset(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset]) if request is not None and has_flattened_params: raise ValueError( @@ -901,10 +913,8 @@ def sample_create_dataset(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.CreateDatasetRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.CreateDatasetRequest): request = dataset_service.CreateDatasetRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1007,8 +1017,8 @@ def sample_get_dataset(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1016,10 +1026,8 @@ def sample_get_dataset(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.GetDatasetRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.GetDatasetRequest): request = dataset_service.GetDatasetRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1131,8 +1139,8 @@ def sample_update_dataset(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([dataset, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1140,10 +1148,8 @@ def sample_update_dataset(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.UpdateDatasetRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, dataset_service.UpdateDatasetRequest): request = dataset_service.UpdateDatasetRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1244,8 +1250,8 @@ def sample_list_datasets(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1253,10 +1259,8 @@ def sample_list_datasets(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.ListDatasetsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.ListDatasetsRequest): request = dataset_service.ListDatasetsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1373,8 +1377,8 @@ def sample_delete_dataset(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1382,10 +1386,8 @@ def sample_delete_dataset(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.DeleteDatasetRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.DeleteDatasetRequest): request = dataset_service.DeleteDatasetRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1506,8 +1508,8 @@ def sample_import_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, import_configs]) if request is not None and has_flattened_params: raise ValueError( @@ -1515,10 +1517,8 @@ def sample_import_data(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.ImportDataRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.ImportDataRequest): request = dataset_service.ImportDataRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1639,8 +1639,8 @@ def sample_export_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name, export_config]) if request is not None and has_flattened_params: raise ValueError( @@ -1648,10 +1648,8 @@ def sample_export_data(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.ExportDataRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.ExportDataRequest): request = dataset_service.ExportDataRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1778,8 +1776,8 @@ def sample_create_dataset_version(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset_version]) if request is not None and has_flattened_params: raise ValueError( @@ -1787,10 +1785,8 @@ def sample_create_dataset_version(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.CreateDatasetVersionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.CreateDatasetVersionRequest): request = dataset_service.CreateDatasetVersionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1910,8 +1906,8 @@ def sample_delete_dataset_version(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1919,10 +1915,8 @@ def sample_delete_dataset_version(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.DeleteDatasetVersionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.DeleteDatasetVersionRequest): request = dataset_service.DeleteDatasetVersionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2022,8 +2016,8 @@ def sample_get_dataset_version(): Describes the dataset version. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2031,10 +2025,8 @@ def sample_get_dataset_version(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.GetDatasetVersionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, dataset_service.GetDatasetVersionRequest): request = dataset_service.GetDatasetVersionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2134,8 +2126,8 @@ def sample_list_dataset_versions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2143,10 +2135,8 @@ def sample_list_dataset_versions(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.ListDatasetVersionsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.ListDatasetVersionsRequest): request = dataset_service.ListDatasetVersionsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2258,8 +2248,8 @@ def sample_restore_dataset_version(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2267,10 +2257,8 @@ def sample_restore_dataset_version(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.RestoreDatasetVersionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.RestoreDatasetVersionRequest): request = dataset_service.RestoreDatasetVersionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2376,8 +2364,8 @@ def sample_list_data_items(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2385,10 +2373,8 @@ def sample_list_data_items(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.ListDataItemsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.ListDataItemsRequest): request = dataset_service.ListDataItemsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2487,10 +2473,8 @@ def sample_search_data_items(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.SearchDataItemsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
+ # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.SearchDataItemsRequest): request = dataset_service.SearchDataItemsRequest(request) @@ -2593,8 +2577,8 @@ def sample_list_saved_queries(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2602,10 +2586,8 @@ def sample_list_saved_queries(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.ListSavedQueriesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.ListSavedQueriesRequest): request = dataset_service.ListSavedQueriesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2722,8 +2704,8 @@ def sample_delete_saved_query(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2731,10 +2713,8 @@ def sample_delete_saved_query(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.DeleteSavedQueryRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.DeleteSavedQueryRequest): request = dataset_service.DeleteSavedQueryRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2836,8 +2816,8 @@ def sample_get_annotation_spec(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2845,10 +2825,8 @@ def sample_get_annotation_spec(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.GetAnnotationSpecRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.GetAnnotationSpecRequest): request = dataset_service.GetAnnotationSpecRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2946,8 +2924,8 @@ def sample_list_annotations(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2955,10 +2933,8 @@ def sample_list_annotations(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.ListAnnotationsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.ListAnnotationsRequest): request = dataset_service.ListAnnotationsRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py index 52b88fc971..217f2aaab6 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py @@ -60,7 +60,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -80,14 +80,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. 
credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -97,11 +100,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. 
client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -128,7 +131,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. @@ -169,7 +172,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py index 2afd3dd0cb..01027107c4 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -75,7 +77,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -105,7 +106,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -125,15 +126,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -143,11 +147,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -174,7 +178,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -214,7 +218,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -772,6 +778,101 @@ def list_annotations( ) return self._stubs["list_annotations"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_dataset: gapic_v1.method_async.wrap_method( + self.create_dataset, + default_timeout=None, + client_info=client_info, + ), + self.get_dataset: gapic_v1.method_async.wrap_method( + self.get_dataset, + default_timeout=None, + client_info=client_info, + ), + self.update_dataset: gapic_v1.method_async.wrap_method( + self.update_dataset, + default_timeout=None, + client_info=client_info, + ), + self.list_datasets: gapic_v1.method_async.wrap_method( + self.list_datasets, + default_timeout=None, + client_info=client_info, + ), + self.delete_dataset: gapic_v1.method_async.wrap_method( + self.delete_dataset, + default_timeout=None, + client_info=client_info, + ), + self.import_data: gapic_v1.method_async.wrap_method( + self.import_data, + default_timeout=None, + client_info=client_info, + ), + self.export_data: gapic_v1.method_async.wrap_method( + self.export_data, + default_timeout=None, + client_info=client_info, + ), + self.create_dataset_version: gapic_v1.method_async.wrap_method( + self.create_dataset_version, + default_timeout=None, + client_info=client_info, + ), + self.delete_dataset_version: gapic_v1.method_async.wrap_method( + self.delete_dataset_version, + default_timeout=None, + client_info=client_info, + ), + self.get_dataset_version: gapic_v1.method_async.wrap_method( + self.get_dataset_version, + default_timeout=None, + client_info=client_info, + ), 
+ self.list_dataset_versions: gapic_v1.method_async.wrap_method( + self.list_dataset_versions, + default_timeout=None, + client_info=client_info, + ), + self.restore_dataset_version: gapic_v1.method_async.wrap_method( + self.restore_dataset_version, + default_timeout=None, + client_info=client_info, + ), + self.list_data_items: gapic_v1.method_async.wrap_method( + self.list_data_items, + default_timeout=None, + client_info=client_info, + ), + self.search_data_items: gapic_v1.method_async.wrap_method( + self.search_data_items, + default_timeout=None, + client_info=client_info, + ), + self.list_saved_queries: gapic_v1.method_async.wrap_method( + self.list_saved_queries, + default_timeout=None, + client_info=client_info, + ), + self.delete_saved_query: gapic_v1.method_async.wrap_method( + self.delete_saved_query, + default_timeout=None, + client_info=client_info, + ), + self.get_annotation_spec: gapic_v1.method_async.wrap_method( + self.get_annotation_spec, + default_timeout=None, + client_info=client_info, + ), + self.list_annotations: gapic_v1.method_async.wrap_method( + self.list_annotations, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/async_client.py b/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/async_client.py index 89bab43fd9..4765925d21 100644 --- a/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -229,7 +230,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, DeploymentResourcePoolServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + DeploymentResourcePoolServiceTransport, 
+ Callable[..., DeploymentResourcePoolServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -241,9 +248,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.DeploymentResourcePoolServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,DeploymentResourcePoolServiceTransport,Callable[..., DeploymentResourcePoolServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the DeploymentResourcePoolServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -391,8 +400,8 @@ async def sample_create_deployment_resource_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, deployment_resource_pool, deployment_resource_pool_id] ) @@ -402,9 +411,17 @@ async def sample_create_deployment_resource_pool(): "the individual field arguments should be set." ) - request = deployment_resource_pool_service.CreateDeploymentResourcePoolRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance( + request, + deployment_resource_pool_service.CreateDeploymentResourcePoolRequest, + ): + request = ( + deployment_resource_pool_service.CreateDeploymentResourcePoolRequest( + request + ) + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -417,11 +434,9 @@ async def sample_create_deployment_resource_pool(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_deployment_resource_pool, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_deployment_resource_pool + ] # Certain fields should be provided within the metadata header; # add these here. @@ -519,8 +534,8 @@ async def sample_get_deployment_resource_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -528,9 +543,14 @@ async def sample_get_deployment_resource_pool(): "the individual field arguments should be set." ) - request = deployment_resource_pool_service.GetDeploymentResourcePoolRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, deployment_resource_pool_service.GetDeploymentResourcePoolRequest + ): + request = deployment_resource_pool_service.GetDeploymentResourcePoolRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -539,11 +559,9 @@ async def sample_get_deployment_resource_pool(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_deployment_resource_pool, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_deployment_resource_pool + ] # Certain fields should be provided within the metadata header; # add these here. @@ -636,8 +654,8 @@ async def sample_list_deployment_resource_pools(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -645,9 +663,16 @@ async def sample_list_deployment_resource_pools(): "the individual field arguments should be set." ) - request = deployment_resource_pool_service.ListDeploymentResourcePoolsRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, deployment_resource_pool_service.ListDeploymentResourcePoolsRequest + ): + request = ( + deployment_resource_pool_service.ListDeploymentResourcePoolsRequest( + request + ) + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -656,11 +681,9 @@ async def sample_list_deployment_resource_pools(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_deployment_resource_pools, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_deployment_resource_pools + ] # Certain fields should be provided within the metadata header; # add these here. @@ -772,8 +795,8 @@ async def sample_delete_deployment_resource_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -781,9 +804,17 @@ async def sample_delete_deployment_resource_pool(): "the individual field arguments should be set." ) - request = deployment_resource_pool_service.DeleteDeploymentResourcePoolRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, + deployment_resource_pool_service.DeleteDeploymentResourcePoolRequest, + ): + request = ( + deployment_resource_pool_service.DeleteDeploymentResourcePoolRequest( + request + ) + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -792,11 +823,9 @@ async def sample_delete_deployment_resource_pool(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_deployment_resource_pool, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_deployment_resource_pool + ] # Certain fields should be provided within the metadata header; # add these here. @@ -895,8 +924,8 @@ async def sample_query_deployed_models(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([deployment_resource_pool]) if request is not None and has_flattened_params: raise ValueError( @@ -904,7 +933,14 @@ async def sample_query_deployed_models(): "the individual field arguments should be set." ) - request = deployment_resource_pool_service.QueryDeployedModelsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, deployment_resource_pool_service.QueryDeployedModelsRequest + ): + request = deployment_resource_pool_service.QueryDeployedModelsRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -913,11 +949,9 @@ async def sample_query_deployed_models(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.query_deployed_models, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.query_deployed_models + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/client.py b/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/client.py index 70f2a68e94..c2a613a41b 100644 --- a/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/client.py +++ b/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -593,7 +594,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, DeploymentResourcePoolServiceTransport]] = None, + transport: Optional[ + Union[ + str, + DeploymentResourcePoolServiceTransport, + Callable[..., DeploymentResourcePoolServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -605,9 +612,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, DeploymentResourcePoolServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,DeploymentResourcePoolServiceTransport,Callable[..., DeploymentResourcePoolServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the DeploymentResourcePoolServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -725,8 +734,18 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[DeploymentResourcePoolServiceTransport], + Callable[..., DeploymentResourcePoolServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast( + Callable[..., DeploymentResourcePoolServiceTransport], transport + ) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -839,8 +858,8 @@ def sample_create_deployment_resource_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, deployment_resource_pool, deployment_resource_pool_id] ) @@ -850,10 +869,8 @@ def sample_create_deployment_resource_pool(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a deployment_resource_pool_service.CreateDeploymentResourcePoolRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, deployment_resource_pool_service.CreateDeploymentResourcePoolRequest, @@ -974,8 +991,8 @@ def sample_get_deployment_resource_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -983,10 +1000,8 @@ def sample_get_deployment_resource_pool(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a deployment_resource_pool_service.GetDeploymentResourcePoolRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, deployment_resource_pool_service.GetDeploymentResourcePoolRequest ): @@ -1095,8 +1110,8 @@ def sample_list_deployment_resource_pools(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1104,10 +1119,8 @@ def sample_list_deployment_resource_pools(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a deployment_resource_pool_service.ListDeploymentResourcePoolsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, deployment_resource_pool_service.ListDeploymentResourcePoolsRequest ): @@ -1237,8 +1250,8 @@ def sample_delete_deployment_resource_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1246,10 +1259,8 @@ def sample_delete_deployment_resource_pool(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a deployment_resource_pool_service.DeleteDeploymentResourcePoolRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, deployment_resource_pool_service.DeleteDeploymentResourcePoolRequest, @@ -1367,8 +1378,8 @@ def sample_query_deployed_models(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([deployment_resource_pool]) if request is not None and has_flattened_params: raise ValueError( @@ -1376,10 +1387,8 @@ def sample_query_deployed_models(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a deployment_resource_pool_service.QueryDeployedModelsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
+ # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, deployment_resource_pool_service.QueryDeployedModelsRequest ): diff --git a/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/transports/grpc.py index b133b91dcb..b4a6e08d79 100644 --- a/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/transports/grpc.py @@ -58,7 +58,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -78,14 +78,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. 
+ channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -95,11 +98,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -126,7 +129,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -167,7 +170,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/transports/grpc_asyncio.py index 2145527b09..ba468f4704 100644 --- a/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -73,7 +75,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -103,7 +104,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -123,15 +124,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -141,11 +145,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -172,7 +176,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -212,7 +216,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -411,6 +417,36 @@ def query_deployed_models( ) return self._stubs["query_deployed_models"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_deployment_resource_pool: gapic_v1.method_async.wrap_method( + self.create_deployment_resource_pool, + default_timeout=None, + client_info=client_info, + ), + self.get_deployment_resource_pool: gapic_v1.method_async.wrap_method( + self.get_deployment_resource_pool, + default_timeout=None, + client_info=client_info, + ), + self.list_deployment_resource_pools: gapic_v1.method_async.wrap_method( + self.list_deployment_resource_pools, + default_timeout=None, + client_info=client_info, + ), + self.delete_deployment_resource_pool: gapic_v1.method_async.wrap_method( + self.delete_deployment_resource_pool, + default_timeout=None, + client_info=client_info, + ), + self.query_deployed_models: gapic_v1.method_async.wrap_method( + self.query_deployed_models, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py index 447c616574..4c0b82c0e8 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -50,6 +51,7 @@ from 
google.cloud.aiplatform_v1.types import endpoint as gca_endpoint from google.cloud.aiplatform_v1.types import endpoint_service from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.cloud.aiplatform_v1.types import service_networking from google.cloud.location import locations_pb2 # type: ignore from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import policy_pb2 # type: ignore @@ -222,7 +224,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, EndpointServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, EndpointServiceTransport, Callable[..., EndpointServiceTransport] + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -234,9 +240,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.EndpointServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,EndpointServiceTransport,Callable[..., EndpointServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the EndpointServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -385,8 +393,8 @@ async def sample_create_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, endpoint, endpoint_id]) if request is not None and has_flattened_params: raise ValueError( @@ -394,7 +402,10 @@ async def sample_create_endpoint(): "the individual field arguments should be set." ) - request = endpoint_service.CreateEndpointRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, endpoint_service.CreateEndpointRequest): + request = endpoint_service.CreateEndpointRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -407,11 +418,9 @@ async def sample_create_endpoint(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_endpoint, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_endpoint + ] # Certain fields should be provided within the metadata header; # add these here. @@ -503,8 +512,8 @@ async def sample_get_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -512,7 +521,10 @@ async def sample_get_endpoint(): "the individual field arguments should be set." 
) - request = endpoint_service.GetEndpointRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, endpoint_service.GetEndpointRequest): + request = endpoint_service.GetEndpointRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -521,11 +533,9 @@ async def sample_get_endpoint(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_endpoint, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_endpoint + ] # Certain fields should be provided within the metadata header; # add these here. @@ -613,8 +623,8 @@ async def sample_list_endpoints(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -622,7 +632,10 @@ async def sample_list_endpoints(): "the individual field arguments should be set." ) - request = endpoint_service.ListEndpointsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, endpoint_service.ListEndpointsRequest): + request = endpoint_service.ListEndpointsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -631,11 +644,9 @@ async def sample_list_endpoints(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_endpoints, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_endpoints + ] # Certain fields should be provided within the metadata header; # add these here. @@ -739,8 +750,8 @@ async def sample_update_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -748,7 +759,10 @@ async def sample_update_endpoint(): "the individual field arguments should be set." ) - request = endpoint_service.UpdateEndpointRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, endpoint_service.UpdateEndpointRequest): + request = endpoint_service.UpdateEndpointRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -759,11 +773,9 @@ async def sample_update_endpoint(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_endpoint, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_endpoint + ] # Certain fields should be provided within the metadata header; # add these here. @@ -863,8 +875,8 @@ async def sample_delete_endpoint(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -872,7 +884,10 @@ async def sample_delete_endpoint(): "the individual field arguments should be set." ) - request = endpoint_service.DeleteEndpointRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, endpoint_service.DeleteEndpointRequest): + request = endpoint_service.DeleteEndpointRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -881,11 +896,9 @@ async def sample_delete_endpoint(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_endpoint, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_endpoint + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1022,8 +1035,8 @@ async def sample_deploy_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, traffic_split]) if request is not None and has_flattened_params: raise ValueError( @@ -1031,7 +1044,10 @@ async def sample_deploy_model(): "the individual field arguments should be set." 
) - request = endpoint_service.DeployModelRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, endpoint_service.DeployModelRequest): + request = endpoint_service.DeployModelRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1045,11 +1061,9 @@ async def sample_deploy_model(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.deploy_model, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.deploy_model + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1173,8 +1187,8 @@ async def sample_undeploy_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) if request is not None and has_flattened_params: raise ValueError( @@ -1182,7 +1196,10 @@ async def sample_undeploy_model(): "the individual field arguments should be set." ) - request = endpoint_service.UndeployModelRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, endpoint_service.UndeployModelRequest): + request = endpoint_service.UndeployModelRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -1196,11 +1213,9 @@ async def sample_undeploy_model(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.undeploy_model, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.undeploy_model + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1336,8 +1351,8 @@ async def sample_mutate_deployed_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1345,7 +1360,10 @@ async def sample_mutate_deployed_model(): "the individual field arguments should be set." ) - request = endpoint_service.MutateDeployedModelRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, endpoint_service.MutateDeployedModelRequest): + request = endpoint_service.MutateDeployedModelRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1358,11 +1376,9 @@ async def sample_mutate_deployed_model(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.mutate_deployed_model, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.mutate_deployed_model + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1/services/endpoint_service/client.py index 4c6fad2571..a9f0dc8e0b 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -55,6 +56,7 @@ from google.cloud.aiplatform_v1.types import endpoint as gca_endpoint from google.cloud.aiplatform_v1.types import endpoint_service from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.cloud.aiplatform_v1.types import service_networking from google.cloud.location import locations_pb2 # type: ignore from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import policy_pb2 # type: ignore @@ -626,7 +628,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, EndpointServiceTransport]] = None, + transport: Optional[ + Union[ + str, EndpointServiceTransport, Callable[..., EndpointServiceTransport] + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -638,9 +644,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, EndpointServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. 
+ transport (Optional[Union[str,EndpointServiceTransport,Callable[..., EndpointServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the EndpointServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -752,8 +760,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[EndpointServiceTransport], Callable[..., EndpointServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., EndpointServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -867,8 +882,8 @@ def sample_create_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, endpoint, endpoint_id]) if request is not None and has_flattened_params: raise ValueError( @@ -876,10 +891,8 @@ def sample_create_endpoint(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.CreateEndpointRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
+ # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, endpoint_service.CreateEndpointRequest): request = endpoint_service.CreateEndpointRequest(request) # If we have keyword arguments corresponding to fields on the @@ -985,8 +998,8 @@ def sample_get_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -994,10 +1007,8 @@ def sample_get_endpoint(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.GetEndpointRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, endpoint_service.GetEndpointRequest): request = endpoint_service.GetEndpointRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1095,8 +1106,8 @@ def sample_list_endpoints(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1104,10 +1115,8 @@ def sample_list_endpoints(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.ListEndpointsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, endpoint_service.ListEndpointsRequest): request = endpoint_service.ListEndpointsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1221,8 +1230,8 @@ def sample_update_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1230,10 +1239,8 @@ def sample_update_endpoint(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.UpdateEndpointRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, endpoint_service.UpdateEndpointRequest): request = endpoint_service.UpdateEndpointRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1345,8 +1352,8 @@ def sample_delete_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1354,10 +1361,8 @@ def sample_delete_endpoint(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.DeleteEndpointRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, endpoint_service.DeleteEndpointRequest): request = endpoint_service.DeleteEndpointRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1504,8 +1509,8 @@ def sample_deploy_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, traffic_split]) if request is not None and has_flattened_params: raise ValueError( @@ -1513,10 +1518,8 @@ def sample_deploy_model(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.DeployModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, endpoint_service.DeployModelRequest): request = endpoint_service.DeployModelRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1654,8 +1657,8 @@ def sample_undeploy_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) if request is not None and has_flattened_params: raise ValueError( @@ -1663,10 +1666,8 @@ def sample_undeploy_model(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.UndeployModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, endpoint_service.UndeployModelRequest): request = endpoint_service.UndeployModelRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1816,8 +1817,8 @@ def sample_mutate_deployed_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1825,10 +1826,8 @@ def sample_mutate_deployed_model(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.MutateDeployedModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, endpoint_service.MutateDeployedModelRequest): request = endpoint_service.MutateDeployedModelRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py index 39dfd7c077..9755742b3e 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py @@ -57,7 +57,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -77,14 +77,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. 
This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -94,11 +97,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -125,7 +128,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -166,7 +169,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py index b3e4da59bb..ad4d7b7e9d 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -72,7 +74,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -102,7 +103,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -122,15 +123,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -140,11 +144,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -171,7 +175,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -211,7 +215,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -486,6 +492,51 @@ def mutate_deployed_model( ) return self._stubs["mutate_deployed_model"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_endpoint: gapic_v1.method_async.wrap_method( + self.create_endpoint, + default_timeout=None, + client_info=client_info, + ), + self.get_endpoint: gapic_v1.method_async.wrap_method( + self.get_endpoint, + default_timeout=None, + client_info=client_info, + ), + self.list_endpoints: gapic_v1.method_async.wrap_method( + self.list_endpoints, + default_timeout=None, + client_info=client_info, + ), + self.update_endpoint: gapic_v1.method_async.wrap_method( + self.update_endpoint, + default_timeout=None, + client_info=client_info, + ), + self.delete_endpoint: gapic_v1.method_async.wrap_method( + self.delete_endpoint, + default_timeout=None, + client_info=client_info, + ), + self.deploy_model: gapic_v1.method_async.wrap_method( + self.deploy_model, + default_timeout=None, + client_info=client_info, + ), + self.undeploy_model: gapic_v1.method_async.wrap_method( + self.undeploy_model, + default_timeout=None, + client_info=client_info, + ), + self.mutate_deployed_model: gapic_v1.method_async.wrap_method( + self.mutate_deployed_model, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/async_client.py b/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/async_client.py index 4d70014f20..c61177eb13 
100644 --- a/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -244,7 +245,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, FeatureOnlineStoreAdminServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + FeatureOnlineStoreAdminServiceTransport, + Callable[..., FeatureOnlineStoreAdminServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -256,9 +263,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.FeatureOnlineStoreAdminServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,FeatureOnlineStoreAdminServiceTransport,Callable[..., FeatureOnlineStoreAdminServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FeatureOnlineStoreAdminServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -411,8 +420,8 @@ async def sample_create_feature_online_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, feature_online_store, feature_online_store_id] ) @@ -422,9 +431,16 @@ async def sample_create_feature_online_store(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.CreateFeatureOnlineStoreRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.CreateFeatureOnlineStoreRequest + ): + request = ( + feature_online_store_admin_service.CreateFeatureOnlineStoreRequest( + request + ) + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -437,11 +453,9 @@ async def sample_create_feature_online_store(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_feature_online_store, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_feature_online_store + ] # Certain fields should be provided within the metadata header; # add these here. @@ -537,8 +551,8 @@ async def sample_get_feature_online_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -546,9 +560,14 @@ async def sample_get_feature_online_store(): "the individual field arguments should be set." 
) - request = feature_online_store_admin_service.GetFeatureOnlineStoreRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.GetFeatureOnlineStoreRequest + ): + request = feature_online_store_admin_service.GetFeatureOnlineStoreRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -557,11 +576,9 @@ async def sample_get_feature_online_store(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_feature_online_store, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_feature_online_store + ] # Certain fields should be provided within the metadata header; # add these here. @@ -654,8 +671,8 @@ async def sample_list_feature_online_stores(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -663,9 +680,14 @@ async def sample_list_feature_online_stores(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.ListFeatureOnlineStoresRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance( + request, feature_online_store_admin_service.ListFeatureOnlineStoresRequest + ): + request = feature_online_store_admin_service.ListFeatureOnlineStoresRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -674,11 +696,9 @@ async def sample_list_feature_online_stores(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_feature_online_stores, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_feature_online_stores + ] # Certain fields should be provided within the metadata header; # add these here. @@ -811,8 +831,8 @@ async def sample_update_feature_online_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature_online_store, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -820,9 +840,16 @@ async def sample_update_feature_online_store(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.UpdateFeatureOnlineStoreRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.UpdateFeatureOnlineStoreRequest + ): + request = ( + feature_online_store_admin_service.UpdateFeatureOnlineStoreRequest( + request + ) + ) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -833,11 +860,9 @@ async def sample_update_feature_online_store(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_feature_online_store, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_feature_online_store + ] # Certain fields should be provided within the metadata header; # add these here. @@ -961,8 +986,8 @@ async def sample_delete_feature_online_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: raise ValueError( @@ -970,9 +995,16 @@ async def sample_delete_feature_online_store(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.DeleteFeatureOnlineStoreRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.DeleteFeatureOnlineStoreRequest + ): + request = ( + feature_online_store_admin_service.DeleteFeatureOnlineStoreRequest( + request + ) + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -983,11 +1015,9 @@ async def sample_delete_feature_online_store(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_feature_online_store, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_feature_online_store + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1115,8 +1145,8 @@ async def sample_create_feature_view(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, feature_view, feature_view_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1124,7 +1154,14 @@ async def sample_create_feature_view(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.CreateFeatureViewRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.CreateFeatureViewRequest + ): + request = feature_online_store_admin_service.CreateFeatureViewRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1137,11 +1174,9 @@ async def sample_create_feature_view(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_feature_view, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_feature_view + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -1235,8 +1270,8 @@ async def sample_get_feature_view(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1244,7 +1279,12 @@ async def sample_get_feature_view(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.GetFeatureViewRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.GetFeatureViewRequest + ): + request = feature_online_store_admin_service.GetFeatureViewRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1253,11 +1293,9 @@ async def sample_get_feature_view(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_feature_view, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_feature_view + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1347,8 +1385,8 @@ async def sample_list_feature_views(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1356,7 +1394,14 @@ async def sample_list_feature_views(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.ListFeatureViewsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.ListFeatureViewsRequest + ): + request = feature_online_store_admin_service.ListFeatureViewsRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1365,11 +1410,9 @@ async def sample_list_feature_views(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_feature_views, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_feature_views + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1493,8 +1536,8 @@ async def sample_update_feature_view(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature_view, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1502,7 +1545,14 @@ async def sample_update_feature_view(): "the individual field arguments should be set." 
) - request = feature_online_store_admin_service.UpdateFeatureViewRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.UpdateFeatureViewRequest + ): + request = feature_online_store_admin_service.UpdateFeatureViewRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1513,11 +1563,9 @@ async def sample_update_feature_view(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_feature_view, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_feature_view + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1627,8 +1675,8 @@ async def sample_delete_feature_view(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1636,7 +1684,14 @@ async def sample_delete_feature_view(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.DeleteFeatureViewRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance( + request, feature_online_store_admin_service.DeleteFeatureViewRequest + ): + request = feature_online_store_admin_service.DeleteFeatureViewRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1645,11 +1700,9 @@ async def sample_delete_feature_view(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_feature_view, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_feature_view + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1742,8 +1795,8 @@ async def sample_sync_feature_view(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature_view]) if request is not None and has_flattened_params: raise ValueError( @@ -1751,7 +1804,12 @@ async def sample_sync_feature_view(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.SyncFeatureViewRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.SyncFeatureViewRequest + ): + request = feature_online_store_admin_service.SyncFeatureViewRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1760,11 +1818,9 @@ async def sample_sync_feature_view(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.sync_feature_view, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.sync_feature_view + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1854,8 +1910,8 @@ async def sample_get_feature_view_sync(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1863,7 +1919,14 @@ async def sample_get_feature_view_sync(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.GetFeatureViewSyncRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.GetFeatureViewSyncRequest + ): + request = feature_online_store_admin_service.GetFeatureViewSyncRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1872,11 +1935,9 @@ async def sample_get_feature_view_sync(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_feature_view_sync, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_feature_view_sync + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -1966,8 +2027,8 @@ async def sample_list_feature_view_syncs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1975,9 +2036,14 @@ async def sample_list_feature_view_syncs(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.ListFeatureViewSyncsRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.ListFeatureViewSyncsRequest + ): + request = feature_online_store_admin_service.ListFeatureViewSyncsRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1986,11 +2052,9 @@ async def sample_list_feature_view_syncs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_feature_view_syncs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_feature_view_syncs + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/client.py b/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/client.py index 225638640c..13adc9fe55 100644 --- a/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/client.py +++ b/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -608,7 +609,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, FeatureOnlineStoreAdminServiceTransport]] = None, + transport: Optional[ + Union[ + str, + FeatureOnlineStoreAdminServiceTransport, + Callable[..., FeatureOnlineStoreAdminServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -620,9 +627,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, FeatureOnlineStoreAdminServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,FeatureOnlineStoreAdminServiceTransport,Callable[..., FeatureOnlineStoreAdminServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FeatureOnlineStoreAdminServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -740,8 +749,18 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[FeatureOnlineStoreAdminServiceTransport], + Callable[..., FeatureOnlineStoreAdminServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast( + Callable[..., FeatureOnlineStoreAdminServiceTransport], transport + ) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -859,8 +878,8 @@ def sample_create_feature_online_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, feature_online_store, feature_online_store_id] ) @@ -870,10 +889,8 @@ def sample_create_feature_online_store(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.CreateFeatureOnlineStoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.CreateFeatureOnlineStoreRequest ): @@ -991,8 +1008,8 @@ def sample_get_feature_online_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1000,10 +1017,8 @@ def sample_get_feature_online_store(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.GetFeatureOnlineStoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.GetFeatureOnlineStoreRequest ): @@ -1110,8 +1125,8 @@ def sample_list_feature_online_stores(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1119,10 +1134,8 @@ def sample_list_feature_online_stores(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.ListFeatureOnlineStoresRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.ListFeatureOnlineStoresRequest ): @@ -1271,8 +1284,8 @@ def sample_update_feature_online_store(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature_online_store, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1280,10 +1293,8 @@ def sample_update_feature_online_store(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.UpdateFeatureOnlineStoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.UpdateFeatureOnlineStoreRequest ): @@ -1427,8 +1438,8 @@ def sample_delete_feature_online_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: raise ValueError( @@ -1436,10 +1447,8 @@ def sample_delete_feature_online_store(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.DeleteFeatureOnlineStoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, feature_online_store_admin_service.DeleteFeatureOnlineStoreRequest ): @@ -1587,8 +1596,8 @@ def sample_create_feature_view(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, feature_view, feature_view_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1596,10 +1605,8 @@ def sample_create_feature_view(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.CreateFeatureViewRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.CreateFeatureViewRequest ): @@ -1711,8 +1718,8 @@ def sample_get_feature_view(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1720,10 +1727,8 @@ def sample_get_feature_view(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.GetFeatureViewRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
+ # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.GetFeatureViewRequest ): @@ -1825,8 +1830,8 @@ def sample_list_feature_views(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1834,10 +1839,8 @@ def sample_list_feature_views(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.ListFeatureViewsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.ListFeatureViewsRequest ): @@ -1975,8 +1978,8 @@ def sample_update_feature_view(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature_view, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1984,10 +1987,8 @@ def sample_update_feature_view(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.UpdateFeatureViewRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.UpdateFeatureViewRequest ): @@ -2113,8 +2114,8 @@ def sample_delete_feature_view(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2122,10 +2123,8 @@ def sample_delete_feature_view(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.DeleteFeatureViewRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.DeleteFeatureViewRequest ): @@ -2232,8 +2231,8 @@ def sample_sync_feature_view(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature_view]) if request is not None and has_flattened_params: raise ValueError( @@ -2241,10 +2240,8 @@ def sample_sync_feature_view(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.SyncFeatureViewRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.SyncFeatureViewRequest ): @@ -2346,8 +2343,8 @@ def sample_get_feature_view_sync(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2355,10 +2352,8 @@ def sample_get_feature_view_sync(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.GetFeatureViewSyncRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.GetFeatureViewSyncRequest ): @@ -2462,8 +2457,8 @@ def sample_list_feature_view_syncs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2471,10 +2466,8 @@ def sample_list_feature_view_syncs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.ListFeatureViewSyncsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.ListFeatureViewSyncsRequest ): diff --git a/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/transports/grpc.py index 3c397adc02..cf29a41abf 100644 --- a/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/transports/grpc.py @@ -61,7 +61,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -81,14 +81,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. 
credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -98,11 +101,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. 
client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -129,7 +132,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. @@ -170,7 +173,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/transports/grpc_asyncio.py index db4e6be5d7..4220625ac0 100644 --- a/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -76,7 +78,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. 
These are only used when credentials are not specified and are passed to :func:`google.auth.default`. @@ -106,7 +107,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -126,15 +127,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. 
If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -144,11 +148,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -175,7 +179,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -215,7 +219,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -642,6 +648,76 @@ def list_feature_view_syncs( ) return self._stubs["list_feature_view_syncs"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_feature_online_store: gapic_v1.method_async.wrap_method( + self.create_feature_online_store, + default_timeout=None, + client_info=client_info, + ), + self.get_feature_online_store: gapic_v1.method_async.wrap_method( + self.get_feature_online_store, + default_timeout=None, + client_info=client_info, + ), + self.list_feature_online_stores: gapic_v1.method_async.wrap_method( + self.list_feature_online_stores, + default_timeout=None, + client_info=client_info, + ), + self.update_feature_online_store: gapic_v1.method_async.wrap_method( + self.update_feature_online_store, + default_timeout=None, + client_info=client_info, + ), + self.delete_feature_online_store: gapic_v1.method_async.wrap_method( + self.delete_feature_online_store, + default_timeout=None, + client_info=client_info, + ), + self.create_feature_view: gapic_v1.method_async.wrap_method( + self.create_feature_view, + default_timeout=None, + client_info=client_info, + ), + self.get_feature_view: gapic_v1.method_async.wrap_method( + self.get_feature_view, + default_timeout=None, + client_info=client_info, + ), + self.list_feature_views: gapic_v1.method_async.wrap_method( + self.list_feature_views, + default_timeout=None, + client_info=client_info, + ), + self.update_feature_view: gapic_v1.method_async.wrap_method( + self.update_feature_view, + default_timeout=None, + 
client_info=client_info, + ), + self.delete_feature_view: gapic_v1.method_async.wrap_method( + self.delete_feature_view, + default_timeout=None, + client_info=client_info, + ), + self.sync_feature_view: gapic_v1.method_async.wrap_method( + self.sync_feature_view, + default_timeout=None, + client_info=client_info, + ), + self.get_feature_view_sync: gapic_v1.method_async.wrap_method( + self.get_feature_view_sync, + default_timeout=None, + client_info=client_info, + ), + self.list_feature_view_syncs: gapic_v1.method_async.wrap_method( + self.list_feature_view_syncs, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/feature_online_store_service/async_client.py b/google/cloud/aiplatform_v1/services/feature_online_store_service/async_client.py index 264a9d6581..24c2058ee3 100644 --- a/google/cloud/aiplatform_v1/services/feature_online_store_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/feature_online_store_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -208,7 +209,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, FeatureOnlineStoreServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + FeatureOnlineStoreServiceTransport, + Callable[..., FeatureOnlineStoreServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -220,9 +227,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.FeatureOnlineStoreServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. 
+ transport (Optional[Union[str,FeatureOnlineStoreServiceTransport,Callable[..., FeatureOnlineStoreServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FeatureOnlineStoreServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -342,8 +351,8 @@ async def sample_fetch_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature_view, data_key]) if request is not None and has_flattened_params: raise ValueError( @@ -351,7 +360,12 @@ async def sample_fetch_feature_values(): "the individual field arguments should be set." ) - request = feature_online_store_service.FetchFeatureValuesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_service.FetchFeatureValuesRequest + ): + request = feature_online_store_service.FetchFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -362,11 +376,9 @@ async def sample_fetch_feature_values(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.fetch_feature_values, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.fetch_feature_values + ] # Certain fields should be provided within the metadata header; # add these here. @@ -452,15 +464,18 @@ async def sample_search_nearest_entities(): """ # Create or coerce a protobuf request object. - request = feature_online_store_service.SearchNearestEntitiesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_service.SearchNearestEntitiesRequest + ): + request = feature_online_store_service.SearchNearestEntitiesRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.search_nearest_entities, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.search_nearest_entities + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1/services/feature_online_store_service/client.py b/google/cloud/aiplatform_v1/services/feature_online_store_service/client.py index 8bb2c9acc6..476fdddf0f 100644 --- a/google/cloud/aiplatform_v1/services/feature_online_store_service/client.py +++ b/google/cloud/aiplatform_v1/services/feature_online_store_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -536,7 +537,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, FeatureOnlineStoreServiceTransport]] = None, + transport: Optional[ + Union[ + str, + FeatureOnlineStoreServiceTransport, + Callable[..., FeatureOnlineStoreServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -548,9 +555,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, FeatureOnlineStoreServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,FeatureOnlineStoreServiceTransport,Callable[..., FeatureOnlineStoreServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FeatureOnlineStoreServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -664,8 +673,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[FeatureOnlineStoreServiceTransport], + Callable[..., FeatureOnlineStoreServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., FeatureOnlineStoreServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -750,8 +767,8 @@ def sample_fetch_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature_view, data_key]) if request is not None and has_flattened_params: raise ValueError( @@ -759,10 +776,8 @@ def sample_fetch_feature_values(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_service.FetchFeatureValuesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_service.FetchFeatureValuesRequest ): @@ -862,10 +877,8 @@ def sample_search_nearest_entities(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_service.SearchNearestEntitiesRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_service.SearchNearestEntitiesRequest ): diff --git a/google/cloud/aiplatform_v1/services/feature_online_store_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/feature_online_store_service/transports/grpc.py index 61543be966..e81cf57f3e 100644 --- a/google/cloud/aiplatform_v1/services/feature_online_store_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/feature_online_store_service/transports/grpc.py @@ -54,7 +54,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -74,14 +74,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. 
+ channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -91,11 +94,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -121,7 +124,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -162,7 +165,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/feature_online_store_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/feature_online_store_service/transports/grpc_asyncio.py index decf80913d..db1ea2ec2f 100644 --- a/google/cloud/aiplatform_v1/services/feature_online_store_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/feature_online_store_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -69,7 +71,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -99,7 +100,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -119,15 +120,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -137,11 +141,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -167,7 +171,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -207,7 +211,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -297,6 +303,21 @@ def search_nearest_entities( ) return self._stubs["search_nearest_entities"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.fetch_feature_values: gapic_v1.method_async.wrap_method( + self.fetch_feature_values, + default_timeout=None, + client_info=client_info, + ), + self.search_nearest_entities: gapic_v1.method_async.wrap_method( + self.search_nearest_entities, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/feature_registry_service/async_client.py b/google/cloud/aiplatform_v1/services/feature_registry_service/async_client.py index 514961e126..2a2e28dc1d 100644 --- a/google/cloud/aiplatform_v1/services/feature_registry_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/feature_registry_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -217,7 +218,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, FeatureRegistryServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + FeatureRegistryServiceTransport, + Callable[..., FeatureRegistryServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -229,9 +236,11 @@ def __init__( credentials identify the application to the 
service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.FeatureRegistryServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,FeatureRegistryServiceTransport,Callable[..., FeatureRegistryServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FeatureRegistryServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -336,7 +345,7 @@ async def sample_create_feature_group(): parent (:class:`str`): Required. The resource name of the Location to create FeatureGroups. Format: - ``projects/{project}/locations/{location}'`` + ``projects/{project}/locations/{location}`` This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this @@ -377,8 +386,8 @@ async def sample_create_feature_group(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, feature_group, feature_group_id]) if request is not None and has_flattened_params: raise ValueError( @@ -386,7 +395,10 @@ async def sample_create_feature_group(): "the individual field arguments should be set." ) - request = feature_registry_service.CreateFeatureGroupRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, feature_registry_service.CreateFeatureGroupRequest): + request = feature_registry_service.CreateFeatureGroupRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -399,11 +411,9 @@ async def sample_create_feature_group(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_feature_group, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_feature_group + ] # Certain fields should be provided within the metadata header; # add these here. @@ -494,8 +504,8 @@ async def sample_get_feature_group(): Vertex AI Feature Group. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -503,7 +513,10 @@ async def sample_get_feature_group(): "the individual field arguments should be set." ) - request = feature_registry_service.GetFeatureGroupRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, feature_registry_service.GetFeatureGroupRequest): + request = feature_registry_service.GetFeatureGroupRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -512,11 +525,9 @@ async def sample_get_feature_group(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_feature_group, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_feature_group + ] # Certain fields should be provided within the metadata header; # add these here. @@ -606,8 +617,8 @@ async def sample_list_feature_groups(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -615,7 +626,10 @@ async def sample_list_feature_groups(): "the individual field arguments should be set." ) - request = feature_registry_service.ListFeatureGroupsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, feature_registry_service.ListFeatureGroupsRequest): + request = feature_registry_service.ListFeatureGroupsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -624,11 +638,9 @@ async def sample_list_feature_groups(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_feature_groups, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_feature_groups + ] # Certain fields should be provided within the metadata header; # add these here. @@ -751,8 +763,8 @@ async def sample_update_feature_group(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature_group, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -760,7 +772,10 @@ async def sample_update_feature_group(): "the individual field arguments should be set." ) - request = feature_registry_service.UpdateFeatureGroupRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, feature_registry_service.UpdateFeatureGroupRequest): + request = feature_registry_service.UpdateFeatureGroupRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -771,11 +786,9 @@ async def sample_update_feature_group(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_feature_group, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_feature_group + ] # Certain fields should be provided within the metadata header; # add these here. @@ -895,8 +908,8 @@ async def sample_delete_feature_group(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name, force]) if request is not None and has_flattened_params: raise ValueError( @@ -904,7 +917,10 @@ async def sample_delete_feature_group(): "the individual field arguments should be set." ) - request = feature_registry_service.DeleteFeatureGroupRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, feature_registry_service.DeleteFeatureGroupRequest): + request = feature_registry_service.DeleteFeatureGroupRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -915,11 +931,9 @@ async def sample_delete_feature_group(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_feature_group, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_feature_group + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1048,8 +1062,8 @@ async def sample_create_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, feature, feature_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1057,7 +1071,10 @@ async def sample_create_feature(): "the individual field arguments should be set." ) - request = featurestore_service.CreateFeatureRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, featurestore_service.CreateFeatureRequest): + request = featurestore_service.CreateFeatureRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1070,11 +1087,9 @@ async def sample_create_feature(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_feature, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_feature + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1171,8 +1186,8 @@ async def sample_get_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1180,7 +1195,10 @@ async def sample_get_feature(): "the individual field arguments should be set." ) - request = featurestore_service.GetFeatureRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.GetFeatureRequest): + request = featurestore_service.GetFeatureRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1189,11 +1207,9 @@ async def sample_get_feature(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_feature, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_feature + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1287,8 +1303,8 @@ async def sample_list_features(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1296,7 +1312,10 @@ async def sample_list_features(): "the individual field arguments should be set." ) - request = featurestore_service.ListFeaturesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.ListFeaturesRequest): + request = featurestore_service.ListFeaturesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1305,11 +1324,9 @@ async def sample_list_features(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_features, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_features + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1434,8 +1451,8 @@ async def sample_update_feature(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1443,7 +1460,10 @@ async def sample_update_feature(): "the individual field arguments should be set." ) - request = featurestore_service.UpdateFeatureRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.UpdateFeatureRequest): + request = featurestore_service.UpdateFeatureRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1454,11 +1474,9 @@ async def sample_update_feature(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_feature, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_feature + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1571,8 +1589,8 @@ async def sample_delete_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1580,7 +1598,10 @@ async def sample_delete_feature(): "the individual field arguments should be set." 
) - request = featurestore_service.DeleteFeatureRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.DeleteFeatureRequest): + request = featurestore_service.DeleteFeatureRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1589,11 +1610,9 @@ async def sample_delete_feature(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_feature, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_feature + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1/services/feature_registry_service/client.py b/google/cloud/aiplatform_v1/services/feature_registry_service/client.py index 9bb13b6eee..d4fbe567fa 100644 --- a/google/cloud/aiplatform_v1/services/feature_registry_service/client.py +++ b/google/cloud/aiplatform_v1/services/feature_registry_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -573,7 +574,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, FeatureRegistryServiceTransport]] = None, + transport: Optional[ + Union[ + str, + FeatureRegistryServiceTransport, + Callable[..., FeatureRegistryServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -585,9 +592,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials 
from the environment. - transport (Union[str, FeatureRegistryServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,FeatureRegistryServiceTransport,Callable[..., FeatureRegistryServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FeatureRegistryServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -699,8 +708,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[FeatureRegistryServiceTransport], + Callable[..., FeatureRegistryServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., FeatureRegistryServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -770,7 +787,7 @@ def sample_create_feature_group(): parent (str): Required. The resource name of the Location to create FeatureGroups. Format: - ``projects/{project}/locations/{location}'`` + ``projects/{project}/locations/{location}`` This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this @@ -811,8 +828,8 @@ def sample_create_feature_group(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, feature_group, feature_group_id]) if request is not None and has_flattened_params: raise ValueError( @@ -820,10 +837,8 @@ def sample_create_feature_group(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_registry_service.CreateFeatureGroupRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, feature_registry_service.CreateFeatureGroupRequest): request = feature_registry_service.CreateFeatureGroupRequest(request) # If we have keyword arguments corresponding to fields on the @@ -928,8 +943,8 @@ def sample_get_feature_group(): Vertex AI Feature Group. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -937,10 +952,8 @@ def sample_get_feature_group(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_registry_service.GetFeatureGroupRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, feature_registry_service.GetFeatureGroupRequest): request = feature_registry_service.GetFeatureGroupRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1040,8 +1053,8 @@ def sample_list_feature_groups(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1049,10 +1062,8 @@ def sample_list_feature_groups(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_registry_service.ListFeatureGroupsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, feature_registry_service.ListFeatureGroupsRequest): request = feature_registry_service.ListFeatureGroupsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1185,8 +1196,8 @@ def sample_update_feature_group(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature_group, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1194,10 +1205,8 @@ def sample_update_feature_group(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a feature_registry_service.UpdateFeatureGroupRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, feature_registry_service.UpdateFeatureGroupRequest): request = feature_registry_service.UpdateFeatureGroupRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1329,8 +1338,8 @@ def sample_delete_feature_group(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: raise ValueError( @@ -1338,10 +1347,8 @@ def sample_delete_feature_group(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_registry_service.DeleteFeatureGroupRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, feature_registry_service.DeleteFeatureGroupRequest): request = feature_registry_service.DeleteFeatureGroupRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1482,8 +1489,8 @@ def sample_create_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, feature, feature_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1491,10 +1498,8 @@ def sample_create_feature(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.CreateFeatureRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.CreateFeatureRequest): request = featurestore_service.CreateFeatureRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1605,8 +1610,8 @@ def sample_get_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1614,10 +1619,8 @@ def sample_get_feature(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.GetFeatureRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, featurestore_service.GetFeatureRequest): request = featurestore_service.GetFeatureRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1721,8 +1724,8 @@ def sample_list_features(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1730,10 +1733,8 @@ def sample_list_features(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.ListFeaturesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.ListFeaturesRequest): request = featurestore_service.ListFeaturesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1868,8 +1869,8 @@ def sample_update_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1877,10 +1878,8 @@ def sample_update_feature(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.UpdateFeatureRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.UpdateFeatureRequest): request = featurestore_service.UpdateFeatureRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2005,8 +2004,8 @@ def sample_delete_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2014,10 +2013,8 @@ def sample_delete_feature(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.DeleteFeatureRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, featurestore_service.DeleteFeatureRequest): request = featurestore_service.DeleteFeatureRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1/services/feature_registry_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/feature_registry_service/transports/grpc.py index 28129c9dfa..98ed8e0e88 100644 --- a/google/cloud/aiplatform_v1/services/feature_registry_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/feature_registry_service/transports/grpc.py @@ -59,7 +59,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -79,14 +79,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. 
+ channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -96,11 +99,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -127,7 +130,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -168,7 +171,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/feature_registry_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/feature_registry_service/transports/grpc_asyncio.py index d47be7da02..9692da7bfd 100644 --- a/google/cloud/aiplatform_v1/services/feature_registry_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/feature_registry_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -74,7 +76,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -104,7 +105,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -124,15 +125,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -142,11 +146,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -173,7 +177,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -213,7 +217,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -543,6 +549,61 @@ def delete_feature( ) return self._stubs["delete_feature"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_feature_group: gapic_v1.method_async.wrap_method( + self.create_feature_group, + default_timeout=None, + client_info=client_info, + ), + self.get_feature_group: gapic_v1.method_async.wrap_method( + self.get_feature_group, + default_timeout=None, + client_info=client_info, + ), + self.list_feature_groups: gapic_v1.method_async.wrap_method( + self.list_feature_groups, + default_timeout=None, + client_info=client_info, + ), + self.update_feature_group: gapic_v1.method_async.wrap_method( + self.update_feature_group, + default_timeout=None, + client_info=client_info, + ), + self.delete_feature_group: gapic_v1.method_async.wrap_method( + self.delete_feature_group, + default_timeout=None, + client_info=client_info, + ), + self.create_feature: gapic_v1.method_async.wrap_method( + self.create_feature, + default_timeout=None, + client_info=client_info, + ), + self.get_feature: gapic_v1.method_async.wrap_method( + self.get_feature, + default_timeout=None, + client_info=client_info, + ), + self.list_features: gapic_v1.method_async.wrap_method( + self.list_features, + default_timeout=None, + client_info=client_info, + ), + self.update_feature: gapic_v1.method_async.wrap_method( + self.update_feature, + default_timeout=None, + client_info=client_info, + ), + self.delete_feature: gapic_v1.method_async.wrap_method( + self.delete_feature, + default_timeout=None, + 
client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/async_client.py b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/async_client.py index 68304b0b8f..0f71befed6 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -216,8 +217,12 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[ - str, FeaturestoreOnlineServingServiceTransport + transport: Optional[ + Union[ + str, + FeaturestoreOnlineServingServiceTransport, + Callable[..., FeaturestoreOnlineServingServiceTransport], + ] ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, @@ -230,9 +235,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.FeaturestoreOnlineServingServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,FeaturestoreOnlineServingServiceTransport,Callable[..., FeaturestoreOnlineServingServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FeaturestoreOnlineServingServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -354,8 +361,8 @@ async def sample_read_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: raise ValueError( @@ -363,7 +370,12 @@ async def sample_read_feature_values(): "the individual field arguments should be set." ) - request = featurestore_online_service.ReadFeatureValuesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, featurestore_online_service.ReadFeatureValuesRequest + ): + request = featurestore_online_service.ReadFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -372,11 +384,9 @@ async def sample_read_feature_values(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.read_feature_values, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.read_feature_values + ] # Certain fields should be provided within the metadata header; # add these here. @@ -477,8 +487,8 @@ async def sample_streaming_read_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: raise ValueError( @@ -486,7 +496,14 @@ async def sample_streaming_read_feature_values(): "the individual field arguments should be set." ) - request = featurestore_online_service.StreamingReadFeatureValuesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, featurestore_online_service.StreamingReadFeatureValuesRequest + ): + request = featurestore_online_service.StreamingReadFeatureValuesRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -495,11 +512,9 @@ async def sample_streaming_read_feature_values(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.streaming_read_feature_values, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.streaming_read_feature_values + ] # Certain fields should be provided within the metadata header; # add these here. @@ -608,8 +623,8 @@ async def sample_write_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type, payloads]) if request is not None and has_flattened_params: raise ValueError( @@ -617,7 +632,12 @@ async def sample_write_feature_values(): "the individual field arguments should be set." 
) - request = featurestore_online_service.WriteFeatureValuesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, featurestore_online_service.WriteFeatureValuesRequest + ): + request = featurestore_online_service.WriteFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -628,11 +648,9 @@ async def sample_write_feature_values(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.write_feature_values, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.write_feature_values + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/client.py b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/client.py index aa381759b5..ddcd7aaba1 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/client.py +++ b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -544,7 +545,11 @@ def __init__( *, credentials: Optional[ga_credentials.Credentials] = None, transport: Optional[ - Union[str, FeaturestoreOnlineServingServiceTransport] + Union[ + str, + FeaturestoreOnlineServingServiceTransport, + Callable[..., FeaturestoreOnlineServingServiceTransport], + ] ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, @@ -557,9 +562,11 @@ def __init__( credentials identify the application to the 
service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, FeaturestoreOnlineServingServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,FeaturestoreOnlineServingServiceTransport,Callable[..., FeaturestoreOnlineServingServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FeaturestoreOnlineServingServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -677,8 +684,18 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[FeaturestoreOnlineServingServiceTransport], + Callable[..., FeaturestoreOnlineServingServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast( + Callable[..., FeaturestoreOnlineServingServiceTransport], transport + ) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -765,8 +782,8 @@ def sample_read_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: raise ValueError( @@ -774,10 +791,8 @@ def sample_read_feature_values(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_online_service.ReadFeatureValuesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, featurestore_online_service.ReadFeatureValuesRequest ): @@ -888,8 +903,8 @@ def sample_streaming_read_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: raise ValueError( @@ -897,10 +912,8 @@ def sample_streaming_read_feature_values(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_online_service.StreamingReadFeatureValuesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, featurestore_online_service.StreamingReadFeatureValuesRequest ): @@ -1025,8 +1038,8 @@ def sample_write_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type, payloads]) if request is not None and has_flattened_params: raise ValueError( @@ -1034,10 +1047,8 @@ def sample_write_feature_values(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_online_service.WriteFeatureValuesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, featurestore_online_service.WriteFeatureValuesRequest ): diff --git a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc.py index f3caf2377b..e95345ca7c 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc.py @@ -56,7 +56,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -76,14 +76,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. 
+ This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -93,11 +96,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. 
client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -123,7 +126,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. @@ -164,7 +167,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc_asyncio.py index ed9927fd35..115cd13090 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -71,7 +73,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -101,7 +102,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -121,15 +122,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -139,11 +143,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -169,7 +173,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -209,7 +213,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -336,6 +342,26 @@ def write_feature_values( ) return self._stubs["write_feature_values"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.read_feature_values: gapic_v1.method_async.wrap_method( + self.read_feature_values, + default_timeout=None, + client_info=client_info, + ), + self.streaming_read_feature_values: gapic_v1.method_async.wrap_method( + self.streaming_read_feature_values, + default_timeout=None, + client_info=client_info, + ), + self.write_feature_values: gapic_v1.method_async.wrap_method( + self.write_feature_values, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/featurestore_service/async_client.py b/google/cloud/aiplatform_v1/services/featurestore_service/async_client.py index 4d6e47e09b..aed331944a 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/featurestore_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -222,7 +223,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, FeaturestoreServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + FeaturestoreServiceTransport, + Callable[..., FeaturestoreServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: 
gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -234,9 +241,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.FeaturestoreServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,FeaturestoreServiceTransport,Callable[..., FeaturestoreServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FeaturestoreServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -379,8 +388,8 @@ async def sample_create_featurestore(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, featurestore, featurestore_id]) if request is not None and has_flattened_params: raise ValueError( @@ -388,7 +397,10 @@ async def sample_create_featurestore(): "the individual field arguments should be set." ) - request = featurestore_service.CreateFeaturestoreRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.CreateFeaturestoreRequest): + request = featurestore_service.CreateFeaturestoreRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -401,11 +413,9 @@ async def sample_create_featurestore(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_featurestore, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_featurestore + ] # Certain fields should be provided within the metadata header; # add these here. @@ -501,8 +511,8 @@ async def sample_get_featurestore(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -510,7 +520,10 @@ async def sample_get_featurestore(): "the individual field arguments should be set." ) - request = featurestore_service.GetFeaturestoreRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.GetFeaturestoreRequest): + request = featurestore_service.GetFeaturestoreRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -519,11 +532,9 @@ async def sample_get_featurestore(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_featurestore, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_featurestore + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -613,8 +624,8 @@ async def sample_list_featurestores(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -622,7 +633,10 @@ async def sample_list_featurestores(): "the individual field arguments should be set." ) - request = featurestore_service.ListFeaturestoresRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.ListFeaturestoresRequest): + request = featurestore_service.ListFeaturestoresRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -631,11 +645,9 @@ async def sample_list_featurestores(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_featurestores, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_featurestores + ] # Certain fields should be provided within the metadata header; # add these here. @@ -758,8 +770,8 @@ async def sample_update_featurestore(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([featurestore, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -767,7 +779,10 @@ async def sample_update_featurestore(): "the individual field arguments should be set." ) - request = featurestore_service.UpdateFeaturestoreRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.UpdateFeaturestoreRequest): + request = featurestore_service.UpdateFeaturestoreRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -778,11 +793,9 @@ async def sample_update_featurestore(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_featurestore, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_featurestore + ] # Certain fields should be provided within the metadata header; # add these here. @@ -905,8 +918,8 @@ async def sample_delete_featurestore(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: raise ValueError( @@ -914,7 +927,10 @@ async def sample_delete_featurestore(): "the individual field arguments should be set." ) - request = featurestore_service.DeleteFeaturestoreRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, featurestore_service.DeleteFeaturestoreRequest): + request = featurestore_service.DeleteFeaturestoreRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -925,11 +941,9 @@ async def sample_delete_featurestore(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_featurestore, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_featurestore + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1053,8 +1067,8 @@ async def sample_create_entity_type(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, entity_type, entity_type_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1062,7 +1076,10 @@ async def sample_create_entity_type(): "the individual field arguments should be set." ) - request = featurestore_service.CreateEntityTypeRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.CreateEntityTypeRequest): + request = featurestore_service.CreateEntityTypeRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1075,11 +1092,9 @@ async def sample_create_entity_type(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_entity_type, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_entity_type + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1176,8 +1191,8 @@ async def sample_get_entity_type(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1185,7 +1200,10 @@ async def sample_get_entity_type(): "the individual field arguments should be set." ) - request = featurestore_service.GetEntityTypeRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.GetEntityTypeRequest): + request = featurestore_service.GetEntityTypeRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1194,11 +1212,9 @@ async def sample_get_entity_type(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_entity_type, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_entity_type + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1288,8 +1304,8 @@ async def sample_list_entity_types(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1297,7 +1313,10 @@ async def sample_list_entity_types(): "the individual field arguments should be set." ) - request = featurestore_service.ListEntityTypesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.ListEntityTypesRequest): + request = featurestore_service.ListEntityTypesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1306,11 +1325,9 @@ async def sample_list_entity_types(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_entity_types, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_entity_types + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1435,8 +1452,8 @@ async def sample_update_entity_type(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1444,7 +1461,10 @@ async def sample_update_entity_type(): "the individual field arguments should be set." 
) - request = featurestore_service.UpdateEntityTypeRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.UpdateEntityTypeRequest): + request = featurestore_service.UpdateEntityTypeRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1455,11 +1475,9 @@ async def sample_update_entity_type(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_entity_type, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_entity_type + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1573,8 +1591,8 @@ async def sample_delete_entity_type(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: raise ValueError( @@ -1582,7 +1600,10 @@ async def sample_delete_entity_type(): "the individual field arguments should be set." ) - request = featurestore_service.DeleteEntityTypeRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.DeleteEntityTypeRequest): + request = featurestore_service.DeleteEntityTypeRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -1593,11 +1614,9 @@ async def sample_delete_entity_type(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_entity_type, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_entity_type + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1726,8 +1745,8 @@ async def sample_create_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, feature, feature_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1735,7 +1754,10 @@ async def sample_create_feature(): "the individual field arguments should be set." ) - request = featurestore_service.CreateFeatureRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.CreateFeatureRequest): + request = featurestore_service.CreateFeatureRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1748,11 +1770,9 @@ async def sample_create_feature(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_feature, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_feature + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -1871,8 +1891,8 @@ async def sample_batch_create_features(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: raise ValueError( @@ -1880,7 +1900,10 @@ async def sample_batch_create_features(): "the individual field arguments should be set." ) - request = featurestore_service.BatchCreateFeaturesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.BatchCreateFeaturesRequest): + request = featurestore_service.BatchCreateFeaturesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1891,11 +1914,9 @@ async def sample_batch_create_features(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.batch_create_features, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_create_features + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1992,8 +2013,8 @@ async def sample_get_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2001,7 +2022,10 @@ async def sample_get_feature(): "the individual field arguments should be set." ) - request = featurestore_service.GetFeatureRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.GetFeatureRequest): + request = featurestore_service.GetFeatureRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2010,11 +2034,9 @@ async def sample_get_feature(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_feature, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_feature + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2108,8 +2130,8 @@ async def sample_list_features(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2117,7 +2139,10 @@ async def sample_list_features(): "the individual field arguments should be set." ) - request = featurestore_service.ListFeaturesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, featurestore_service.ListFeaturesRequest): + request = featurestore_service.ListFeaturesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2126,11 +2151,9 @@ async def sample_list_features(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_features, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_features + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2249,8 +2272,8 @@ async def sample_update_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -2258,7 +2281,10 @@ async def sample_update_feature(): "the individual field arguments should be set." ) - request = featurestore_service.UpdateFeatureRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.UpdateFeatureRequest): + request = featurestore_service.UpdateFeatureRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2269,11 +2295,9 @@ async def sample_update_feature(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_feature, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_feature + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2378,8 +2402,8 @@ async def sample_delete_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2387,7 +2411,10 @@ async def sample_delete_feature(): "the individual field arguments should be set." ) - request = featurestore_service.DeleteFeatureRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.DeleteFeatureRequest): + request = featurestore_service.DeleteFeatureRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2396,11 +2423,9 @@ async def sample_delete_feature(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_feature, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_feature + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2534,8 +2559,8 @@ async def sample_import_feature_values(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: raise ValueError( @@ -2543,7 +2568,10 @@ async def sample_import_feature_values(): "the individual field arguments should be set." ) - request = featurestore_service.ImportFeatureValuesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.ImportFeatureValuesRequest): + request = featurestore_service.ImportFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2552,11 +2580,9 @@ async def sample_import_feature_values(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.import_feature_values, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.import_feature_values + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2677,8 +2703,8 @@ async def sample_batch_read_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([featurestore]) if request is not None and has_flattened_params: raise ValueError( @@ -2686,7 +2712,10 @@ async def sample_batch_read_feature_values(): "the individual field arguments should be set." ) - request = featurestore_service.BatchReadFeatureValuesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.BatchReadFeatureValuesRequest): + request = featurestore_service.BatchReadFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2695,11 +2724,9 @@ async def sample_batch_read_feature_values(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.batch_read_feature_values, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_read_feature_values + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2810,8 +2837,8 @@ async def sample_export_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: raise ValueError( @@ -2819,7 +2846,10 @@ async def sample_export_feature_values(): "the individual field arguments should be set." ) - request = featurestore_service.ExportFeatureValuesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, featurestore_service.ExportFeatureValuesRequest): + request = featurestore_service.ExportFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2828,11 +2858,9 @@ async def sample_export_feature_values(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.export_feature_values, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.export_feature_values + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2951,8 +2979,8 @@ async def sample_delete_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: raise ValueError( @@ -2960,7 +2988,10 @@ async def sample_delete_feature_values(): "the individual field arguments should be set." ) - request = featurestore_service.DeleteFeatureValuesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.DeleteFeatureValuesRequest): + request = featurestore_service.DeleteFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2969,11 +3000,9 @@ async def sample_delete_feature_values(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_feature_values, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_feature_values + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3150,8 +3179,8 @@ async def sample_search_features(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([location, query]) if request is not None and has_flattened_params: raise ValueError( @@ -3159,7 +3188,10 @@ async def sample_search_features(): "the individual field arguments should be set." ) - request = featurestore_service.SearchFeaturesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.SearchFeaturesRequest): + request = featurestore_service.SearchFeaturesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3170,11 +3202,9 @@ async def sample_search_features(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.search_features, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.search_features + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1/services/featurestore_service/client.py b/google/cloud/aiplatform_v1/services/featurestore_service/client.py index 655bb2fc6b..39822417e9 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_service/client.py +++ b/google/cloud/aiplatform_v1/services/featurestore_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -598,7 +599,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, FeaturestoreServiceTransport]] = None, + transport: Optional[ + Union[ + str, + FeaturestoreServiceTransport, + Callable[..., FeaturestoreServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -610,9 +617,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, FeaturestoreServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,FeaturestoreServiceTransport,Callable[..., FeaturestoreServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FeaturestoreServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -724,8 +733,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[FeaturestoreServiceTransport], + Callable[..., FeaturestoreServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., FeaturestoreServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -833,8 +850,8 @@ def sample_create_featurestore(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, featurestore, featurestore_id]) if request is not None and has_flattened_params: raise ValueError( @@ -842,10 +859,8 @@ def sample_create_featurestore(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.CreateFeaturestoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.CreateFeaturestoreRequest): request = featurestore_service.CreateFeaturestoreRequest(request) # If we have keyword arguments corresponding to fields on the @@ -955,8 +970,8 @@ def sample_get_featurestore(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -964,10 +979,8 @@ def sample_get_featurestore(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.GetFeaturestoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.GetFeaturestoreRequest): request = featurestore_service.GetFeaturestoreRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1067,8 +1080,8 @@ def sample_list_featurestores(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1076,10 +1089,8 @@ def sample_list_featurestores(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.ListFeaturestoresRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, featurestore_service.ListFeaturestoresRequest): request = featurestore_service.ListFeaturestoresRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1212,8 +1223,8 @@ def sample_update_featurestore(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([featurestore, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1221,10 +1232,8 @@ def sample_update_featurestore(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.UpdateFeaturestoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.UpdateFeaturestoreRequest): request = featurestore_service.UpdateFeaturestoreRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1359,8 +1368,8 @@ def sample_delete_featurestore(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: raise ValueError( @@ -1368,10 +1377,8 @@ def sample_delete_featurestore(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.DeleteFeaturestoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.DeleteFeaturestoreRequest): request = featurestore_service.DeleteFeaturestoreRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1507,8 +1514,8 @@ def sample_create_entity_type(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, entity_type, entity_type_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1516,10 +1523,8 @@ def sample_create_entity_type(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.CreateEntityTypeRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.CreateEntityTypeRequest): request = featurestore_service.CreateEntityTypeRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1630,8 +1635,8 @@ def sample_get_entity_type(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1639,10 +1644,8 @@ def sample_get_entity_type(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.GetEntityTypeRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.GetEntityTypeRequest): request = featurestore_service.GetEntityTypeRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1742,8 +1745,8 @@ def sample_list_entity_types(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1751,10 +1754,8 @@ def sample_list_entity_types(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.ListEntityTypesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, featurestore_service.ListEntityTypesRequest): request = featurestore_service.ListEntityTypesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1889,8 +1890,8 @@ def sample_update_entity_type(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1898,10 +1899,8 @@ def sample_update_entity_type(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.UpdateEntityTypeRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.UpdateEntityTypeRequest): request = featurestore_service.UpdateEntityTypeRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2027,8 +2026,8 @@ def sample_delete_entity_type(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: raise ValueError( @@ -2036,10 +2035,8 @@ def sample_delete_entity_type(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.DeleteEntityTypeRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.DeleteEntityTypeRequest): request = featurestore_service.DeleteEntityTypeRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2180,8 +2177,8 @@ def sample_create_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, feature, feature_id]) if request is not None and has_flattened_params: raise ValueError( @@ -2189,10 +2186,8 @@ def sample_create_feature(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.CreateFeatureRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.CreateFeatureRequest): request = featurestore_service.CreateFeatureRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2325,8 +2320,8 @@ def sample_batch_create_features(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: raise ValueError( @@ -2334,10 +2329,8 @@ def sample_batch_create_features(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.BatchCreateFeaturesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.BatchCreateFeaturesRequest): request = featurestore_service.BatchCreateFeaturesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2446,8 +2439,8 @@ def sample_get_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2455,10 +2448,8 @@ def sample_get_feature(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.GetFeatureRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, featurestore_service.GetFeatureRequest): request = featurestore_service.GetFeatureRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2562,8 +2553,8 @@ def sample_list_features(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2571,10 +2562,8 @@ def sample_list_features(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.ListFeaturesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.ListFeaturesRequest): request = featurestore_service.ListFeaturesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2703,8 +2692,8 @@ def sample_update_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -2712,10 +2701,8 @@ def sample_update_feature(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.UpdateFeatureRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.UpdateFeatureRequest): request = featurestore_service.UpdateFeatureRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2832,8 +2819,8 @@ def sample_delete_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2841,10 +2828,8 @@ def sample_delete_feature(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.DeleteFeatureRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.DeleteFeatureRequest): request = featurestore_service.DeleteFeatureRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2988,8 +2973,8 @@ def sample_import_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: raise ValueError( @@ -2997,10 +2982,8 @@ def sample_import_feature_values(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.ImportFeatureValuesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.ImportFeatureValuesRequest): request = featurestore_service.ImportFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3131,8 +3114,8 @@ def sample_batch_read_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([featurestore]) if request is not None and has_flattened_params: raise ValueError( @@ -3140,10 +3123,8 @@ def sample_batch_read_feature_values(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.BatchReadFeatureValuesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, featurestore_service.BatchReadFeatureValuesRequest): request = featurestore_service.BatchReadFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3266,8 +3247,8 @@ def sample_export_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: raise ValueError( @@ -3275,10 +3256,8 @@ def sample_export_feature_values(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.ExportFeatureValuesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.ExportFeatureValuesRequest): request = featurestore_service.ExportFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3407,8 +3386,8 @@ def sample_delete_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: raise ValueError( @@ -3416,10 +3395,8 @@ def sample_delete_feature_values(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.DeleteFeatureValuesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.DeleteFeatureValuesRequest): request = featurestore_service.DeleteFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3606,8 +3583,8 @@ def sample_search_features(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([location, query]) if request is not None and has_flattened_params: raise ValueError( @@ -3615,10 +3592,8 @@ def sample_search_features(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.SearchFeaturesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, featurestore_service.SearchFeaturesRequest): request = featurestore_service.SearchFeaturesRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1/services/featurestore_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/featurestore_service/transports/grpc.py index 4ff9f04a39..e2aa6e1ec3 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/featurestore_service/transports/grpc.py @@ -61,7 +61,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -81,14 +81,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. 
+ channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -98,11 +101,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -129,7 +132,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -170,7 +173,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/featurestore_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/featurestore_service/transports/grpc_asyncio.py index 8e6f143a0a..84252a70d7 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/featurestore_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -76,7 +78,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -106,7 +107,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -126,15 +127,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -144,11 +148,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -175,7 +179,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -215,7 +219,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -911,6 +917,116 @@ def search_features( ) return self._stubs["search_features"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_featurestore: gapic_v1.method_async.wrap_method( + self.create_featurestore, + default_timeout=None, + client_info=client_info, + ), + self.get_featurestore: gapic_v1.method_async.wrap_method( + self.get_featurestore, + default_timeout=None, + client_info=client_info, + ), + self.list_featurestores: gapic_v1.method_async.wrap_method( + self.list_featurestores, + default_timeout=None, + client_info=client_info, + ), + self.update_featurestore: gapic_v1.method_async.wrap_method( + self.update_featurestore, + default_timeout=None, + client_info=client_info, + ), + self.delete_featurestore: gapic_v1.method_async.wrap_method( + self.delete_featurestore, + default_timeout=None, + client_info=client_info, + ), + self.create_entity_type: gapic_v1.method_async.wrap_method( + self.create_entity_type, + default_timeout=None, + client_info=client_info, + ), + self.get_entity_type: gapic_v1.method_async.wrap_method( + self.get_entity_type, + default_timeout=None, + client_info=client_info, + ), + self.list_entity_types: gapic_v1.method_async.wrap_method( + self.list_entity_types, + default_timeout=None, + client_info=client_info, + ), + self.update_entity_type: gapic_v1.method_async.wrap_method( + self.update_entity_type, + default_timeout=None, + client_info=client_info, + ), + self.delete_entity_type: gapic_v1.method_async.wrap_method( + self.delete_entity_type, + 
default_timeout=None, + client_info=client_info, + ), + self.create_feature: gapic_v1.method_async.wrap_method( + self.create_feature, + default_timeout=None, + client_info=client_info, + ), + self.batch_create_features: gapic_v1.method_async.wrap_method( + self.batch_create_features, + default_timeout=None, + client_info=client_info, + ), + self.get_feature: gapic_v1.method_async.wrap_method( + self.get_feature, + default_timeout=None, + client_info=client_info, + ), + self.list_features: gapic_v1.method_async.wrap_method( + self.list_features, + default_timeout=None, + client_info=client_info, + ), + self.update_feature: gapic_v1.method_async.wrap_method( + self.update_feature, + default_timeout=None, + client_info=client_info, + ), + self.delete_feature: gapic_v1.method_async.wrap_method( + self.delete_feature, + default_timeout=None, + client_info=client_info, + ), + self.import_feature_values: gapic_v1.method_async.wrap_method( + self.import_feature_values, + default_timeout=None, + client_info=client_info, + ), + self.batch_read_feature_values: gapic_v1.method_async.wrap_method( + self.batch_read_feature_values, + default_timeout=None, + client_info=client_info, + ), + self.export_feature_values: gapic_v1.method_async.wrap_method( + self.export_feature_values, + default_timeout=None, + client_info=client_info, + ), + self.delete_feature_values: gapic_v1.method_async.wrap_method( + self.delete_feature_values, + default_timeout=None, + client_info=client_info, + ), + self.search_features: gapic_v1.method_async.wrap_method( + self.search_features, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/async_client.py b/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/async_client.py index 22873bd4c3..9fb81b9ddd 100644 --- a/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/async_client.py +++ 
b/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -209,7 +210,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, GenAiTuningServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + GenAiTuningServiceTransport, + Callable[..., GenAiTuningServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -221,9 +228,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.GenAiTuningServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,GenAiTuningServiceTransport,Callable[..., GenAiTuningServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the GenAiTuningServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -346,8 +355,8 @@ async def sample_create_tuning_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, tuning_job]) if request is not None and has_flattened_params: raise ValueError( @@ -355,7 +364,10 @@ async def sample_create_tuning_job(): "the individual field arguments should be set." ) - request = genai_tuning_service.CreateTuningJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, genai_tuning_service.CreateTuningJobRequest): + request = genai_tuning_service.CreateTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -366,11 +378,9 @@ async def sample_create_tuning_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_tuning_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -453,8 +463,8 @@ async def sample_get_tuning_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -462,7 +472,10 @@ async def sample_get_tuning_job(): "the individual field arguments should be set." ) - request = genai_tuning_service.GetTuningJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, genai_tuning_service.GetTuningJobRequest): + request = genai_tuning_service.GetTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -471,11 +484,9 @@ async def sample_get_tuning_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_tuning_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -565,8 +576,8 @@ async def sample_list_tuning_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -574,7 +585,10 @@ async def sample_list_tuning_jobs(): "the individual field arguments should be set." ) - request = genai_tuning_service.ListTuningJobsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, genai_tuning_service.ListTuningJobsRequest): + request = genai_tuning_service.ListTuningJobsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -583,11 +597,9 @@ async def sample_list_tuning_jobs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_tuning_jobs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_tuning_jobs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -684,8 +696,8 @@ async def sample_cancel_tuning_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -693,7 +705,10 @@ async def sample_cancel_tuning_job(): "the individual field arguments should be set." ) - request = genai_tuning_service.CancelTuningJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, genai_tuning_service.CancelTuningJobRequest): + request = genai_tuning_service.CancelTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -702,11 +717,9 @@ async def sample_cancel_tuning_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_tuning_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.cancel_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/client.py b/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/client.py index 38dcca44fc..6680eb9319 100644 --- a/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/client.py +++ b/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -605,7 +606,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, GenAiTuningServiceTransport]] = None, + transport: Optional[ + Union[ + str, + GenAiTuningServiceTransport, + Callable[..., GenAiTuningServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -617,9 +624,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, GenAiTuningServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,GenAiTuningServiceTransport,Callable[..., GenAiTuningServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the GenAiTuningServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -731,8 +740,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[GenAiTuningServiceTransport], + Callable[..., GenAiTuningServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., GenAiTuningServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -820,8 +837,8 @@ def sample_create_tuning_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tuning_job]) if request is not None and has_flattened_params: raise ValueError( @@ -829,10 +846,8 @@ def sample_create_tuning_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a genai_tuning_service.CreateTuningJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, genai_tuning_service.CreateTuningJobRequest): request = genai_tuning_service.CreateTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -927,8 +942,8 @@ def sample_get_tuning_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -936,10 +951,8 @@ def sample_get_tuning_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a genai_tuning_service.GetTuningJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, genai_tuning_service.GetTuningJobRequest): request = genai_tuning_service.GetTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1039,8 +1052,8 @@ def sample_list_tuning_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1048,10 +1061,8 @@ def sample_list_tuning_jobs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a genai_tuning_service.ListTuningJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, genai_tuning_service.ListTuningJobsRequest): request = genai_tuning_service.ListTuningJobsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1158,8 +1169,8 @@ def sample_cancel_tuning_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1167,10 +1178,8 @@ def sample_cancel_tuning_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a genai_tuning_service.CancelTuningJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, genai_tuning_service.CancelTuningJobRequest): request = genai_tuning_service.CancelTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/transports/grpc.py index 17349e1ccd..7f4d78def9 100644 --- a/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/transports/grpc.py @@ -57,7 +57,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -77,14 +77,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. 
+ channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -94,11 +97,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -124,7 +127,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -165,7 +168,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/transports/grpc_asyncio.py index 281195f55d..3da68a1632 100644 --- a/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -72,7 +74,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -102,7 +103,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -122,15 +123,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -140,11 +144,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -170,7 +174,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -210,7 +214,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -366,6 +372,31 @@ def cancel_tuning_job( ) return self._stubs["cancel_tuning_job"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_tuning_job: gapic_v1.method_async.wrap_method( + self.create_tuning_job, + default_timeout=None, + client_info=client_info, + ), + self.get_tuning_job: gapic_v1.method_async.wrap_method( + self.get_tuning_job, + default_timeout=None, + client_info=client_info, + ), + self.list_tuning_jobs: gapic_v1.method_async.wrap_method( + self.list_tuning_jobs, + default_timeout=None, + client_info=client_info, + ), + self.cancel_tuning_job: gapic_v1.method_async.wrap_method( + self.cancel_tuning_job, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/index_endpoint_service/async_client.py b/google/cloud/aiplatform_v1/services/index_endpoint_service/async_client.py index dd1d5c9f63..cfe8d69444 100644 --- a/google/cloud/aiplatform_v1/services/index_endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/index_endpoint_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -212,7 +213,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, IndexEndpointServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + IndexEndpointServiceTransport, + Callable[..., IndexEndpointServiceTransport], + ] 
+ ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -224,9 +231,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.IndexEndpointServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,IndexEndpointServiceTransport,Callable[..., IndexEndpointServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the IndexEndpointServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -355,8 +364,8 @@ async def sample_create_index_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, index_endpoint]) if request is not None and has_flattened_params: raise ValueError( @@ -364,7 +373,10 @@ async def sample_create_index_endpoint(): "the individual field arguments should be set." ) - request = index_endpoint_service.CreateIndexEndpointRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, index_endpoint_service.CreateIndexEndpointRequest): + request = index_endpoint_service.CreateIndexEndpointRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -375,11 +387,9 @@ async def sample_create_index_endpoint(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_index_endpoint, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_index_endpoint + ] # Certain fields should be provided within the metadata header; # add these here. @@ -474,8 +484,8 @@ async def sample_get_index_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -483,7 +493,10 @@ async def sample_get_index_endpoint(): "the individual field arguments should be set." ) - request = index_endpoint_service.GetIndexEndpointRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_endpoint_service.GetIndexEndpointRequest): + request = index_endpoint_service.GetIndexEndpointRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -492,11 +505,9 @@ async def sample_get_index_endpoint(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_index_endpoint, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_index_endpoint + ] # Certain fields should be provided within the metadata header; # add these here. @@ -586,8 +597,8 @@ async def sample_list_index_endpoints(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -595,7 +606,10 @@ async def sample_list_index_endpoints(): "the individual field arguments should be set." ) - request = index_endpoint_service.ListIndexEndpointsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_endpoint_service.ListIndexEndpointsRequest): + request = index_endpoint_service.ListIndexEndpointsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -604,11 +618,9 @@ async def sample_list_index_endpoints(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_index_endpoints, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_index_endpoints + ] # Certain fields should be provided within the metadata header; # add these here. @@ -714,8 +726,8 @@ async def sample_update_index_endpoint(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -723,7 +735,10 @@ async def sample_update_index_endpoint(): "the individual field arguments should be set." ) - request = index_endpoint_service.UpdateIndexEndpointRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_endpoint_service.UpdateIndexEndpointRequest): + request = index_endpoint_service.UpdateIndexEndpointRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -734,11 +749,9 @@ async def sample_update_index_endpoint(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_index_endpoint, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_index_endpoint + ] # Certain fields should be provided within the metadata header; # add these here. @@ -840,8 +853,8 @@ async def sample_delete_index_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -849,7 +862,10 @@ async def sample_delete_index_endpoint(): "the individual field arguments should be set." ) - request = index_endpoint_service.DeleteIndexEndpointRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_endpoint_service.DeleteIndexEndpointRequest): + request = index_endpoint_service.DeleteIndexEndpointRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -858,11 +874,9 @@ async def sample_delete_index_endpoint(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_index_endpoint, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_index_endpoint + ] # Certain fields should be provided within the metadata header; # add these here. @@ -977,8 +991,8 @@ async def sample_deploy_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index]) if request is not None and has_flattened_params: raise ValueError( @@ -986,7 +1000,10 @@ async def sample_deploy_index(): "the individual field arguments should be set." ) - request = index_endpoint_service.DeployIndexRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, index_endpoint_service.DeployIndexRequest): + request = index_endpoint_service.DeployIndexRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -997,11 +1014,9 @@ async def sample_deploy_index(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.deploy_index, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.deploy_index + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1114,8 +1129,8 @@ async def sample_undeploy_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1123,7 +1138,10 @@ async def sample_undeploy_index(): "the individual field arguments should be set." ) - request = index_endpoint_service.UndeployIndexRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_endpoint_service.UndeployIndexRequest): + request = index_endpoint_service.UndeployIndexRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1134,11 +1152,9 @@ async def sample_undeploy_index(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.undeploy_index, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.undeploy_index + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1256,8 +1272,8 @@ async def sample_mutate_deployed_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index]) if request is not None and has_flattened_params: raise ValueError( @@ -1265,7 +1281,10 @@ async def sample_mutate_deployed_index(): "the individual field arguments should be set." ) - request = index_endpoint_service.MutateDeployedIndexRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_endpoint_service.MutateDeployedIndexRequest): + request = index_endpoint_service.MutateDeployedIndexRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1276,11 +1295,9 @@ async def sample_mutate_deployed_index(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.mutate_deployed_index, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.mutate_deployed_index + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1/services/index_endpoint_service/client.py b/google/cloud/aiplatform_v1/services/index_endpoint_service/client.py index 0b8b84ea90..10534efc91 100644 --- a/google/cloud/aiplatform_v1/services/index_endpoint_service/client.py +++ b/google/cloud/aiplatform_v1/services/index_endpoint_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -564,7 +565,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, IndexEndpointServiceTransport]] = None, + transport: Optional[ + Union[ + str, + IndexEndpointServiceTransport, + Callable[..., IndexEndpointServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -576,9 +583,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, IndexEndpointServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,IndexEndpointServiceTransport,Callable[..., IndexEndpointServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the IndexEndpointServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -690,8 +699,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[IndexEndpointServiceTransport], + Callable[..., IndexEndpointServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., IndexEndpointServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -785,8 +802,8 @@ def sample_create_index_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, index_endpoint]) if request is not None and has_flattened_params: raise ValueError( @@ -794,10 +811,8 @@ def sample_create_index_endpoint(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_endpoint_service.CreateIndexEndpointRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_endpoint_service.CreateIndexEndpointRequest): request = index_endpoint_service.CreateIndexEndpointRequest(request) # If we have keyword arguments corresponding to fields on the @@ -904,8 +919,8 @@ def sample_get_index_endpoint(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -913,10 +928,8 @@ def sample_get_index_endpoint(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_endpoint_service.GetIndexEndpointRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_endpoint_service.GetIndexEndpointRequest): request = index_endpoint_service.GetIndexEndpointRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1016,8 +1029,8 @@ def sample_list_index_endpoints(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1025,10 +1038,8 @@ def sample_list_index_endpoints(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_endpoint_service.ListIndexEndpointsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, index_endpoint_service.ListIndexEndpointsRequest): request = index_endpoint_service.ListIndexEndpointsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1144,8 +1155,8 @@ def sample_update_index_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1153,10 +1164,8 @@ def sample_update_index_endpoint(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_endpoint_service.UpdateIndexEndpointRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_endpoint_service.UpdateIndexEndpointRequest): request = index_endpoint_service.UpdateIndexEndpointRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1270,8 +1279,8 @@ def sample_delete_index_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1279,10 +1288,8 @@ def sample_delete_index_endpoint(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a index_endpoint_service.DeleteIndexEndpointRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_endpoint_service.DeleteIndexEndpointRequest): request = index_endpoint_service.DeleteIndexEndpointRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1407,8 +1414,8 @@ def sample_deploy_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index]) if request is not None and has_flattened_params: raise ValueError( @@ -1416,10 +1423,8 @@ def sample_deploy_index(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_endpoint_service.DeployIndexRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_endpoint_service.DeployIndexRequest): request = index_endpoint_service.DeployIndexRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1544,8 +1549,8 @@ def sample_undeploy_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1553,10 +1558,8 @@ def sample_undeploy_index(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_endpoint_service.UndeployIndexRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_endpoint_service.UndeployIndexRequest): request = index_endpoint_service.UndeployIndexRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1686,8 +1689,8 @@ def sample_mutate_deployed_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index]) if request is not None and has_flattened_params: raise ValueError( @@ -1695,10 +1698,8 @@ def sample_mutate_deployed_index(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_endpoint_service.MutateDeployedIndexRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, index_endpoint_service.MutateDeployedIndexRequest): request = index_endpoint_service.MutateDeployedIndexRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1/services/index_endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/index_endpoint_service/transports/grpc.py index 6c3da6bd06..e37f8927fe 100644 --- a/google/cloud/aiplatform_v1/services/index_endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/index_endpoint_service/transports/grpc.py @@ -57,7 +57,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -77,14 +77,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. 
+ channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -94,11 +97,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -125,7 +128,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -166,7 +169,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/index_endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/index_endpoint_service/transports/grpc_asyncio.py index e9526117e0..9981e9b1cb 100644 --- a/google/cloud/aiplatform_v1/services/index_endpoint_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/index_endpoint_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -72,7 +74,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -102,7 +103,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -122,15 +123,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -140,11 +144,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -171,7 +175,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -211,7 +215,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -492,6 +498,51 @@ def mutate_deployed_index( ) return self._stubs["mutate_deployed_index"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_index_endpoint: gapic_v1.method_async.wrap_method( + self.create_index_endpoint, + default_timeout=None, + client_info=client_info, + ), + self.get_index_endpoint: gapic_v1.method_async.wrap_method( + self.get_index_endpoint, + default_timeout=None, + client_info=client_info, + ), + self.list_index_endpoints: gapic_v1.method_async.wrap_method( + self.list_index_endpoints, + default_timeout=None, + client_info=client_info, + ), + self.update_index_endpoint: gapic_v1.method_async.wrap_method( + self.update_index_endpoint, + default_timeout=None, + client_info=client_info, + ), + self.delete_index_endpoint: gapic_v1.method_async.wrap_method( + self.delete_index_endpoint, + default_timeout=None, + client_info=client_info, + ), + self.deploy_index: gapic_v1.method_async.wrap_method( + self.deploy_index, + default_timeout=None, + client_info=client_info, + ), + self.undeploy_index: gapic_v1.method_async.wrap_method( + self.undeploy_index, + default_timeout=None, + client_info=client_info, + ), + self.mutate_deployed_index: gapic_v1.method_async.wrap_method( + self.mutate_deployed_index, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/index_service/async_client.py b/google/cloud/aiplatform_v1/services/index_service/async_client.py index 
bd2912b14c..3268898ed6 100644 --- a/google/cloud/aiplatform_v1/services/index_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/index_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -210,7 +211,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, IndexServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[str, IndexServiceTransport, Callable[..., IndexServiceTransport]] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -222,9 +225,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.IndexServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,IndexServiceTransport,Callable[..., IndexServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the IndexServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -350,8 +355,8 @@ async def sample_create_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, index]) if request is not None and has_flattened_params: raise ValueError( @@ -359,7 +364,10 @@ async def sample_create_index(): "the individual field arguments should be set." ) - request = index_service.CreateIndexRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_service.CreateIndexRequest): + request = index_service.CreateIndexRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -370,11 +378,9 @@ async def sample_create_index(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_index, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_index + ] # Certain fields should be provided within the metadata header; # add these here. @@ -467,8 +473,8 @@ async def sample_get_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -476,7 +482,10 @@ async def sample_get_index(): "the individual field arguments should be set." ) - request = index_service.GetIndexRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, index_service.GetIndexRequest): + request = index_service.GetIndexRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -485,11 +494,9 @@ async def sample_get_index(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_index, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_index + ] # Certain fields should be provided within the metadata header; # add these here. @@ -577,8 +584,8 @@ async def sample_list_indexes(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -586,7 +593,10 @@ async def sample_list_indexes(): "the individual field arguments should be set." ) - request = index_service.ListIndexesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_service.ListIndexesRequest): + request = index_service.ListIndexesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -595,11 +605,9 @@ async def sample_list_indexes(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_indexes, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_indexes + ] # Certain fields should be provided within the metadata header; # add these here. @@ -710,8 +718,8 @@ async def sample_update_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([index, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -719,7 +727,10 @@ async def sample_update_index(): "the individual field arguments should be set." ) - request = index_service.UpdateIndexRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_service.UpdateIndexRequest): + request = index_service.UpdateIndexRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -730,11 +741,9 @@ async def sample_update_index(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_index, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_index + ] # Certain fields should be provided within the metadata header; # add these here. @@ -844,8 +853,8 @@ async def sample_delete_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -853,7 +862,10 @@ async def sample_delete_index(): "the individual field arguments should be set." ) - request = index_service.DeleteIndexRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_service.DeleteIndexRequest): + request = index_service.DeleteIndexRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -862,11 +874,9 @@ async def sample_delete_index(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_index, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_index + ] # Certain fields should be provided within the metadata header; # add these here. @@ -949,15 +959,16 @@ async def sample_upsert_datapoints(): """ # Create or coerce a protobuf request object. - request = index_service.UpsertDatapointsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_service.UpsertDatapointsRequest): + request = index_service.UpsertDatapointsRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.upsert_datapoints, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.upsert_datapoints + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1032,15 +1043,16 @@ async def sample_remove_datapoints(): """ # Create or coerce a protobuf request object. - request = index_service.RemoveDatapointsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_service.RemoveDatapointsRequest): + request = index_service.RemoveDatapointsRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.remove_datapoints, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.remove_datapoints + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1/services/index_service/client.py b/google/cloud/aiplatform_v1/services/index_service/client.py index 687b786b63..1157814ac7 100644 --- a/google/cloud/aiplatform_v1/services/index_service/client.py +++ b/google/cloud/aiplatform_v1/services/index_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -565,7 +566,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, IndexServiceTransport]] = None, + transport: Optional[ + Union[str, IndexServiceTransport, Callable[..., IndexServiceTransport]] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -577,9 +580,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, IndexServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,IndexServiceTransport,Callable[..., IndexServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the IndexServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -688,8 +693,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[IndexServiceTransport], Callable[..., IndexServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., IndexServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -780,8 +792,8 @@ def sample_create_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, index]) if request is not None and has_flattened_params: raise ValueError( @@ -789,10 +801,8 @@ def sample_create_index(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_service.CreateIndexRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_service.CreateIndexRequest): request = index_service.CreateIndexRequest(request) # If we have keyword arguments corresponding to fields on the @@ -897,8 +907,8 @@ def sample_get_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -906,10 +916,8 @@ def sample_get_index(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_service.GetIndexRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_service.GetIndexRequest): request = index_service.GetIndexRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1007,8 +1015,8 @@ def sample_list_indexes(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1016,10 +1024,8 @@ def sample_list_indexes(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_service.ListIndexesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, index_service.ListIndexesRequest): request = index_service.ListIndexesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1140,8 +1146,8 @@ def sample_update_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([index, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1149,10 +1155,8 @@ def sample_update_index(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_service.UpdateIndexRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_service.UpdateIndexRequest): request = index_service.UpdateIndexRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1274,8 +1278,8 @@ def sample_delete_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1283,10 +1287,8 @@ def sample_delete_index(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_service.DeleteIndexRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_service.DeleteIndexRequest): request = index_service.DeleteIndexRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1379,10 +1381,8 @@ def sample_upsert_datapoints(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a index_service.UpsertDatapointsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_service.UpsertDatapointsRequest): request = index_service.UpsertDatapointsRequest(request) @@ -1463,10 +1463,8 @@ def sample_remove_datapoints(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a index_service.RemoveDatapointsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, index_service.RemoveDatapointsRequest): request = index_service.RemoveDatapointsRequest(request) diff --git a/google/cloud/aiplatform_v1/services/index_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/index_service/transports/grpc.py index c7e03fced5..90ffc6fb01 100644 --- a/google/cloud/aiplatform_v1/services/index_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/index_service/transports/grpc.py @@ -57,7 +57,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -77,14 +77,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. 
If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -94,11 +97,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -125,7 +128,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -166,7 +169,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/index_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/index_service/transports/grpc_asyncio.py index c2096b59d8..fd5374254b 100644 --- a/google/cloud/aiplatform_v1/services/index_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/index_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -72,7 +74,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -102,7 +103,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -122,15 +123,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -140,11 +144,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -171,7 +175,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -211,7 +215,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -454,6 +460,46 @@ def remove_datapoints( ) return self._stubs["remove_datapoints"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_index: gapic_v1.method_async.wrap_method( + self.create_index, + default_timeout=None, + client_info=client_info, + ), + self.get_index: gapic_v1.method_async.wrap_method( + self.get_index, + default_timeout=None, + client_info=client_info, + ), + self.list_indexes: gapic_v1.method_async.wrap_method( + self.list_indexes, + default_timeout=None, + client_info=client_info, + ), + self.update_index: gapic_v1.method_async.wrap_method( + self.update_index, + default_timeout=None, + client_info=client_info, + ), + self.delete_index: gapic_v1.method_async.wrap_method( + self.delete_index, + default_timeout=None, + client_info=client_info, + ), + self.upsert_datapoints: gapic_v1.method_async.wrap_method( + self.upsert_datapoints, + default_timeout=None, + client_info=client_info, + ), + self.remove_datapoints: gapic_v1.method_async.wrap_method( + self.remove_datapoints, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/job_service/async_client.py b/google/cloud/aiplatform_v1/services/job_service/async_client.py index 505ef38e2b..0c75fff3ee 100644 --- a/google/cloud/aiplatform_v1/services/job_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/job_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, 
Mapping, MutableMapping, MutableSequence, @@ -277,7 +278,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, JobServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[str, JobServiceTransport, Callable[..., JobServiceTransport]] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -289,9 +292,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.JobServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,JobServiceTransport,Callable[..., JobServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the JobServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -418,8 +423,8 @@ async def sample_create_custom_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, custom_job]) if request is not None and has_flattened_params: raise ValueError( @@ -427,7 +432,10 @@ async def sample_create_custom_job(): "the individual field arguments should be set." 
) - request = job_service.CreateCustomJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.CreateCustomJobRequest): + request = job_service.CreateCustomJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -438,11 +446,9 @@ async def sample_create_custom_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_custom_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_custom_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -531,8 +537,8 @@ async def sample_get_custom_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -540,7 +546,10 @@ async def sample_get_custom_job(): "the individual field arguments should be set." ) - request = job_service.GetCustomJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.GetCustomJobRequest): + request = job_service.GetCustomJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -549,11 +558,9 @@ async def sample_get_custom_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_custom_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_custom_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -641,8 +648,8 @@ async def sample_list_custom_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -650,7 +657,10 @@ async def sample_list_custom_jobs(): "the individual field arguments should be set." ) - request = job_service.ListCustomJobsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.ListCustomJobsRequest): + request = job_service.ListCustomJobsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -659,11 +669,9 @@ async def sample_list_custom_jobs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_custom_jobs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_custom_jobs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -770,8 +778,8 @@ async def sample_delete_custom_job(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -779,7 +787,10 @@ async def sample_delete_custom_job(): "the individual field arguments should be set." ) - request = job_service.DeleteCustomJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.DeleteCustomJobRequest): + request = job_service.DeleteCustomJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -788,11 +799,9 @@ async def sample_delete_custom_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_custom_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_custom_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -886,8 +895,8 @@ async def sample_cancel_custom_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -895,7 +904,10 @@ async def sample_cancel_custom_job(): "the individual field arguments should be set." 
) - request = job_service.CancelCustomJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.CancelCustomJobRequest): + request = job_service.CancelCustomJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -904,11 +916,9 @@ async def sample_cancel_custom_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_custom_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.cancel_custom_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1006,8 +1016,8 @@ async def sample_create_data_labeling_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, data_labeling_job]) if request is not None and has_flattened_params: raise ValueError( @@ -1015,7 +1025,10 @@ async def sample_create_data_labeling_job(): "the individual field arguments should be set." ) - request = job_service.CreateDataLabelingJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.CreateDataLabelingJobRequest): + request = job_service.CreateDataLabelingJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -1026,11 +1039,9 @@ async def sample_create_data_labeling_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_data_labeling_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_data_labeling_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1114,8 +1125,8 @@ async def sample_get_data_labeling_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1123,7 +1134,10 @@ async def sample_get_data_labeling_job(): "the individual field arguments should be set." ) - request = job_service.GetDataLabelingJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.GetDataLabelingJobRequest): + request = job_service.GetDataLabelingJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1132,11 +1146,9 @@ async def sample_get_data_labeling_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_data_labeling_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_data_labeling_job + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -1223,8 +1235,8 @@ async def sample_list_data_labeling_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1232,7 +1244,10 @@ async def sample_list_data_labeling_jobs(): "the individual field arguments should be set." ) - request = job_service.ListDataLabelingJobsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.ListDataLabelingJobsRequest): + request = job_service.ListDataLabelingJobsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1241,11 +1256,9 @@ async def sample_list_data_labeling_jobs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_data_labeling_jobs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_data_labeling_jobs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1352,8 +1365,8 @@ async def sample_delete_data_labeling_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1361,7 +1374,10 @@ async def sample_delete_data_labeling_job(): "the individual field arguments should be set." ) - request = job_service.DeleteDataLabelingJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.DeleteDataLabelingJobRequest): + request = job_service.DeleteDataLabelingJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1370,11 +1386,9 @@ async def sample_delete_data_labeling_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_data_labeling_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_data_labeling_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1457,8 +1471,8 @@ async def sample_cancel_data_labeling_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1466,7 +1480,10 @@ async def sample_cancel_data_labeling_job(): "the individual field arguments should be set." ) - request = job_service.CancelDataLabelingJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, job_service.CancelDataLabelingJobRequest): + request = job_service.CancelDataLabelingJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1475,11 +1492,9 @@ async def sample_cancel_data_labeling_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_data_labeling_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.cancel_data_labeling_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1586,8 +1601,8 @@ async def sample_create_hyperparameter_tuning_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, hyperparameter_tuning_job]) if request is not None and has_flattened_params: raise ValueError( @@ -1595,7 +1610,10 @@ async def sample_create_hyperparameter_tuning_job(): "the individual field arguments should be set." ) - request = job_service.CreateHyperparameterTuningJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.CreateHyperparameterTuningJobRequest): + request = job_service.CreateHyperparameterTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1606,11 +1624,9 @@ async def sample_create_hyperparameter_tuning_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_hyperparameter_tuning_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1698,8 +1714,8 @@ async def sample_get_hyperparameter_tuning_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1707,7 +1723,10 @@ async def sample_get_hyperparameter_tuning_job(): "the individual field arguments should be set." ) - request = job_service.GetHyperparameterTuningJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.GetHyperparameterTuningJobRequest): + request = job_service.GetHyperparameterTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1716,11 +1735,9 @@ async def sample_get_hyperparameter_tuning_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_hyperparameter_tuning_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -1810,8 +1827,8 @@ async def sample_list_hyperparameter_tuning_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1819,7 +1836,10 @@ async def sample_list_hyperparameter_tuning_jobs(): "the individual field arguments should be set." ) - request = job_service.ListHyperparameterTuningJobsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.ListHyperparameterTuningJobsRequest): + request = job_service.ListHyperparameterTuningJobsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1828,11 +1848,9 @@ async def sample_list_hyperparameter_tuning_jobs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_hyperparameter_tuning_jobs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_hyperparameter_tuning_jobs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1941,8 +1959,8 @@ async def sample_delete_hyperparameter_tuning_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1950,7 +1968,10 @@ async def sample_delete_hyperparameter_tuning_job(): "the individual field arguments should be set." ) - request = job_service.DeleteHyperparameterTuningJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.DeleteHyperparameterTuningJobRequest): + request = job_service.DeleteHyperparameterTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1959,11 +1980,9 @@ async def sample_delete_hyperparameter_tuning_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_hyperparameter_tuning_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2061,8 +2080,8 @@ async def sample_cancel_hyperparameter_tuning_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2070,7 +2089,10 @@ async def sample_cancel_hyperparameter_tuning_job(): "the individual field arguments should be set." 
) - request = job_service.CancelHyperparameterTuningJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.CancelHyperparameterTuningJobRequest): + request = job_service.CancelHyperparameterTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2079,11 +2101,9 @@ async def sample_cancel_hyperparameter_tuning_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_hyperparameter_tuning_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.cancel_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2177,8 +2197,8 @@ async def sample_create_nas_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, nas_job]) if request is not None and has_flattened_params: raise ValueError( @@ -2186,7 +2206,10 @@ async def sample_create_nas_job(): "the individual field arguments should be set." ) - request = job_service.CreateNasJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.CreateNasJobRequest): + request = job_service.CreateNasJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -2197,11 +2220,9 @@ async def sample_create_nas_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_nas_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_nas_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2284,8 +2305,8 @@ async def sample_get_nas_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2293,7 +2314,10 @@ async def sample_get_nas_job(): "the individual field arguments should be set." ) - request = job_service.GetNasJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.GetNasJobRequest): + request = job_service.GetNasJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2302,11 +2326,9 @@ async def sample_get_nas_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_nas_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_nas_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2394,8 +2416,8 @@ async def sample_list_nas_jobs(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2403,7 +2425,10 @@ async def sample_list_nas_jobs(): "the individual field arguments should be set." ) - request = job_service.ListNasJobsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.ListNasJobsRequest): + request = job_service.ListNasJobsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2412,11 +2437,9 @@ async def sample_list_nas_jobs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_nas_jobs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_nas_jobs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2523,8 +2546,8 @@ async def sample_delete_nas_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2532,7 +2555,10 @@ async def sample_delete_nas_job(): "the individual field arguments should be set." 
) - request = job_service.DeleteNasJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.DeleteNasJobRequest): + request = job_service.DeleteNasJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2541,11 +2567,9 @@ async def sample_delete_nas_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_nas_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_nas_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2639,8 +2663,8 @@ async def sample_cancel_nas_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2648,7 +2672,10 @@ async def sample_cancel_nas_job(): "the individual field arguments should be set." ) - request = job_service.CancelNasJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.CancelNasJobRequest): + request = job_service.CancelNasJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -2657,11 +2684,9 @@ async def sample_cancel_nas_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_nas_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.cancel_nas_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2744,8 +2769,8 @@ async def sample_get_nas_trial_detail(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2753,7 +2778,10 @@ async def sample_get_nas_trial_detail(): "the individual field arguments should be set." ) - request = job_service.GetNasTrialDetailRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.GetNasTrialDetailRequest): + request = job_service.GetNasTrialDetailRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2762,11 +2790,9 @@ async def sample_get_nas_trial_detail(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_nas_trial_detail, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_nas_trial_detail + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -2853,8 +2879,8 @@ async def sample_list_nas_trial_details(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2862,7 +2888,10 @@ async def sample_list_nas_trial_details(): "the individual field arguments should be set." ) - request = job_service.ListNasTrialDetailsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.ListNasTrialDetailsRequest): + request = job_service.ListNasTrialDetailsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2871,11 +2900,9 @@ async def sample_list_nas_trial_details(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_nas_trial_details, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_nas_trial_details + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2994,8 +3021,8 @@ async def sample_create_batch_prediction_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, batch_prediction_job]) if request is not None and has_flattened_params: raise ValueError( @@ -3003,7 +3030,10 @@ async def sample_create_batch_prediction_job(): "the individual field arguments should be set." ) - request = job_service.CreateBatchPredictionJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.CreateBatchPredictionJobRequest): + request = job_service.CreateBatchPredictionJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3014,11 +3044,9 @@ async def sample_create_batch_prediction_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_batch_prediction_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3107,8 +3135,8 @@ async def sample_get_batch_prediction_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3116,7 +3144,10 @@ async def sample_get_batch_prediction_job(): "the individual field arguments should be set." ) - request = job_service.GetBatchPredictionJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, job_service.GetBatchPredictionJobRequest): + request = job_service.GetBatchPredictionJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3125,11 +3156,9 @@ async def sample_get_batch_prediction_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_batch_prediction_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3219,8 +3248,8 @@ async def sample_list_batch_prediction_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3228,7 +3257,10 @@ async def sample_list_batch_prediction_jobs(): "the individual field arguments should be set." ) - request = job_service.ListBatchPredictionJobsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.ListBatchPredictionJobsRequest): + request = job_service.ListBatchPredictionJobsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3237,11 +3269,9 @@ async def sample_list_batch_prediction_jobs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_batch_prediction_jobs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_batch_prediction_jobs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3351,8 +3381,8 @@ async def sample_delete_batch_prediction_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3360,7 +3390,10 @@ async def sample_delete_batch_prediction_job(): "the individual field arguments should be set." ) - request = job_service.DeleteBatchPredictionJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.DeleteBatchPredictionJobRequest): + request = job_service.DeleteBatchPredictionJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3369,11 +3402,9 @@ async def sample_delete_batch_prediction_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_batch_prediction_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3469,8 +3500,8 @@ async def sample_cancel_batch_prediction_job(): sent along with the request as metadata. 
""" # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3478,7 +3509,10 @@ async def sample_cancel_batch_prediction_job(): "the individual field arguments should be set." ) - request = job_service.CancelBatchPredictionJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.CancelBatchPredictionJobRequest): + request = job_service.CancelBatchPredictionJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3487,11 +3521,9 @@ async def sample_cancel_batch_prediction_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_batch_prediction_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.cancel_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3593,8 +3625,8 @@ async def sample_create_model_deployment_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, model_deployment_monitoring_job]) if request is not None and has_flattened_params: raise ValueError( @@ -3602,7 +3634,12 @@ async def sample_create_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - request = job_service.CreateModelDeploymentMonitoringJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, job_service.CreateModelDeploymentMonitoringJobRequest + ): + request = job_service.CreateModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3613,11 +3650,9 @@ async def sample_create_model_deployment_monitoring_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_model_deployment_monitoring_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3719,8 +3754,8 @@ async def sample_search_model_deployment_monitoring_stats_anomalies(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model_deployment_monitoring_job, deployed_model_id]) if request is not None and has_flattened_params: raise ValueError( @@ -3728,9 +3763,14 @@ async def sample_search_model_deployment_monitoring_stats_anomalies(): "the individual field arguments should be set." 
) - request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest + ): + request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3741,11 +3781,9 @@ async def sample_search_model_deployment_monitoring_stats_anomalies(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.search_model_deployment_monitoring_stats_anomalies, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.search_model_deployment_monitoring_stats_anomalies + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3850,8 +3888,8 @@ async def sample_get_model_deployment_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3859,7 +3897,10 @@ async def sample_get_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - request = job_service.GetModelDeploymentMonitoringJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, job_service.GetModelDeploymentMonitoringJobRequest): + request = job_service.GetModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3868,11 +3909,9 @@ async def sample_get_model_deployment_monitoring_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_model_deployment_monitoring_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3962,8 +4001,8 @@ async def sample_list_model_deployment_monitoring_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3971,7 +4010,12 @@ async def sample_list_model_deployment_monitoring_jobs(): "the individual field arguments should be set." ) - request = job_service.ListModelDeploymentMonitoringJobsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, job_service.ListModelDeploymentMonitoringJobsRequest + ): + request = job_service.ListModelDeploymentMonitoringJobsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -3980,11 +4024,9 @@ async def sample_list_model_deployment_monitoring_jobs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_model_deployment_monitoring_jobs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_model_deployment_monitoring_jobs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -4126,8 +4168,8 @@ async def sample_update_model_deployment_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model_deployment_monitoring_job, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -4135,7 +4177,12 @@ async def sample_update_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - request = job_service.UpdateModelDeploymentMonitoringJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, job_service.UpdateModelDeploymentMonitoringJobRequest + ): + request = job_service.UpdateModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -4146,11 +4193,9 @@ async def sample_update_model_deployment_monitoring_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_model_deployment_monitoring_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -4265,8 +4310,8 @@ async def sample_delete_model_deployment_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -4274,7 +4319,12 @@ async def sample_delete_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - request = job_service.DeleteModelDeploymentMonitoringJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, job_service.DeleteModelDeploymentMonitoringJobRequest + ): + request = job_service.DeleteModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -4283,11 +4333,9 @@ async def sample_delete_model_deployment_monitoring_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_model_deployment_monitoring_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -4375,8 +4423,8 @@ async def sample_pause_model_deployment_monitoring_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -4384,7 +4432,12 @@ async def sample_pause_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - request = job_service.PauseModelDeploymentMonitoringJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, job_service.PauseModelDeploymentMonitoringJobRequest + ): + request = job_service.PauseModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -4393,11 +4446,9 @@ async def sample_pause_model_deployment_monitoring_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.pause_model_deployment_monitoring_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.pause_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -4473,8 +4524,8 @@ async def sample_resume_model_deployment_monitoring_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -4482,7 +4533,12 @@ async def sample_resume_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - request = job_service.ResumeModelDeploymentMonitoringJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, job_service.ResumeModelDeploymentMonitoringJobRequest + ): + request = job_service.ResumeModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -4491,11 +4547,9 @@ async def sample_resume_model_deployment_monitoring_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.resume_model_deployment_monitoring_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.resume_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1/services/job_service/client.py b/google/cloud/aiplatform_v1/services/job_service/client.py index ac78b7bd68..62b5ee0a03 100644 --- a/google/cloud/aiplatform_v1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1/services/job_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -901,7 +902,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, JobServiceTransport]] = None, + transport: Optional[ + Union[str, JobServiceTransport, Callable[..., JobServiceTransport]] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -913,9 +916,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, JobServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,JobServiceTransport,Callable[..., JobServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the JobServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -1024,8 +1029,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[JobServiceTransport], Callable[..., JobServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., JobServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -1117,8 +1129,8 @@ def sample_create_custom_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, custom_job]) if request is not None and has_flattened_params: raise ValueError( @@ -1126,10 +1138,8 @@ def sample_create_custom_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CreateCustomJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.CreateCustomJobRequest): request = job_service.CreateCustomJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1230,8 +1240,8 @@ def sample_get_custom_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1239,10 +1249,8 @@ def sample_get_custom_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.GetCustomJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.GetCustomJobRequest): request = job_service.GetCustomJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1340,8 +1348,8 @@ def sample_list_custom_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1349,10 +1357,8 @@ def sample_list_custom_jobs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.ListCustomJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, job_service.ListCustomJobsRequest): request = job_service.ListCustomJobsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1469,8 +1475,8 @@ def sample_delete_custom_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1478,10 +1484,8 @@ def sample_delete_custom_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.DeleteCustomJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.DeleteCustomJobRequest): request = job_service.DeleteCustomJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1585,8 +1589,8 @@ def sample_cancel_custom_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1594,10 +1598,8 @@ def sample_cancel_custom_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CancelCustomJobRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.CancelCustomJobRequest): request = job_service.CancelCustomJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1705,8 +1707,8 @@ def sample_create_data_labeling_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, data_labeling_job]) if request is not None and has_flattened_params: raise ValueError( @@ -1714,10 +1716,8 @@ def sample_create_data_labeling_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CreateDataLabelingJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.CreateDataLabelingJobRequest): request = job_service.CreateDataLabelingJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1813,8 +1813,8 @@ def sample_get_data_labeling_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1822,10 +1822,8 @@ def sample_get_data_labeling_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.GetDataLabelingJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.GetDataLabelingJobRequest): request = job_service.GetDataLabelingJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1922,8 +1920,8 @@ def sample_list_data_labeling_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1931,10 +1929,8 @@ def sample_list_data_labeling_jobs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.ListDataLabelingJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.ListDataLabelingJobsRequest): request = job_service.ListDataLabelingJobsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2051,8 +2047,8 @@ def sample_delete_data_labeling_job(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2060,10 +2056,8 @@ def sample_delete_data_labeling_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.DeleteDataLabelingJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.DeleteDataLabelingJobRequest): request = job_service.DeleteDataLabelingJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2156,8 +2150,8 @@ def sample_cancel_data_labeling_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2165,10 +2159,8 @@ def sample_cancel_data_labeling_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CancelDataLabelingJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, job_service.CancelDataLabelingJobRequest): request = job_service.CancelDataLabelingJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2285,8 +2277,8 @@ def sample_create_hyperparameter_tuning_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, hyperparameter_tuning_job]) if request is not None and has_flattened_params: raise ValueError( @@ -2294,10 +2286,8 @@ def sample_create_hyperparameter_tuning_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CreateHyperparameterTuningJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.CreateHyperparameterTuningJobRequest): request = job_service.CreateHyperparameterTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2399,8 +2389,8 @@ def sample_get_hyperparameter_tuning_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2408,10 +2398,8 @@ def sample_get_hyperparameter_tuning_job(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.GetHyperparameterTuningJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.GetHyperparameterTuningJobRequest): request = job_service.GetHyperparameterTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2513,8 +2501,8 @@ def sample_list_hyperparameter_tuning_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2522,10 +2510,8 @@ def sample_list_hyperparameter_tuning_jobs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.ListHyperparameterTuningJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.ListHyperparameterTuningJobsRequest): request = job_service.ListHyperparameterTuningJobsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2646,8 +2632,8 @@ def sample_delete_hyperparameter_tuning_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2655,10 +2641,8 @@ def sample_delete_hyperparameter_tuning_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.DeleteHyperparameterTuningJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.DeleteHyperparameterTuningJobRequest): request = job_service.DeleteHyperparameterTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2768,8 +2752,8 @@ def sample_cancel_hyperparameter_tuning_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2777,10 +2761,8 @@ def sample_cancel_hyperparameter_tuning_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CancelHyperparameterTuningJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, job_service.CancelHyperparameterTuningJobRequest): request = job_service.CancelHyperparameterTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2886,8 +2868,8 @@ def sample_create_nas_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, nas_job]) if request is not None and has_flattened_params: raise ValueError( @@ -2895,10 +2877,8 @@ def sample_create_nas_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CreateNasJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.CreateNasJobRequest): request = job_service.CreateNasJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2993,8 +2973,8 @@ def sample_get_nas_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3002,10 +2982,8 @@ def sample_get_nas_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.GetNasJobRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.GetNasJobRequest): request = job_service.GetNasJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3103,8 +3081,8 @@ def sample_list_nas_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3112,10 +3090,8 @@ def sample_list_nas_jobs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.ListNasJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.ListNasJobsRequest): request = job_service.ListNasJobsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3232,8 +3208,8 @@ def sample_delete_nas_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3241,10 +3217,8 @@ def sample_delete_nas_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.DeleteNasJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.DeleteNasJobRequest): request = job_service.DeleteNasJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3348,8 +3322,8 @@ def sample_cancel_nas_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3357,10 +3331,8 @@ def sample_cancel_nas_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CancelNasJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.CancelNasJobRequest): request = job_service.CancelNasJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3453,8 +3425,8 @@ def sample_get_nas_trial_detail(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3462,10 +3434,8 @@ def sample_get_nas_trial_detail(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.GetNasTrialDetailRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.GetNasTrialDetailRequest): request = job_service.GetNasTrialDetailRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3562,8 +3532,8 @@ def sample_list_nas_trial_details(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3571,10 +3541,8 @@ def sample_list_nas_trial_details(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.ListNasTrialDetailsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, job_service.ListNasTrialDetailsRequest): request = job_service.ListNasTrialDetailsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3703,8 +3671,8 @@ def sample_create_batch_prediction_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, batch_prediction_job]) if request is not None and has_flattened_params: raise ValueError( @@ -3712,10 +3680,8 @@ def sample_create_batch_prediction_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CreateBatchPredictionJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.CreateBatchPredictionJobRequest): request = job_service.CreateBatchPredictionJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3818,8 +3784,8 @@ def sample_get_batch_prediction_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3827,10 +3793,8 @@ def sample_get_batch_prediction_job(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.GetBatchPredictionJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.GetBatchPredictionJobRequest): request = job_service.GetBatchPredictionJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3930,8 +3894,8 @@ def sample_list_batch_prediction_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3939,10 +3903,8 @@ def sample_list_batch_prediction_jobs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.ListBatchPredictionJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.ListBatchPredictionJobsRequest): request = job_service.ListBatchPredictionJobsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4064,8 +4026,8 @@ def sample_delete_batch_prediction_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -4073,10 +4035,8 @@ def sample_delete_batch_prediction_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.DeleteBatchPredictionJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.DeleteBatchPredictionJobRequest): request = job_service.DeleteBatchPredictionJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4184,8 +4144,8 @@ def sample_cancel_batch_prediction_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -4193,10 +4153,8 @@ def sample_cancel_batch_prediction_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CancelBatchPredictionJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, job_service.CancelBatchPredictionJobRequest): request = job_service.CancelBatchPredictionJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4310,8 +4268,8 @@ def sample_create_model_deployment_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_deployment_monitoring_job]) if request is not None and has_flattened_params: raise ValueError( @@ -4319,10 +4277,8 @@ def sample_create_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CreateModelDeploymentMonitoringJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, job_service.CreateModelDeploymentMonitoringJobRequest ): @@ -4442,8 +4398,8 @@ def sample_search_model_deployment_monitoring_stats_anomalies(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model_deployment_monitoring_job, deployed_model_id]) if request is not None and has_flattened_params: raise ValueError( @@ -4451,10 +4407,8 @@ def sample_search_model_deployment_monitoring_stats_anomalies(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest ): @@ -4579,8 +4533,8 @@ def sample_get_model_deployment_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -4588,10 +4542,8 @@ def sample_get_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.GetModelDeploymentMonitoringJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.GetModelDeploymentMonitoringJobRequest): request = job_service.GetModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4693,8 +4645,8 @@ def sample_list_model_deployment_monitoring_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -4702,10 +4654,8 @@ def sample_list_model_deployment_monitoring_jobs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.ListModelDeploymentMonitoringJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, job_service.ListModelDeploymentMonitoringJobsRequest ): @@ -4861,8 +4811,8 @@ def sample_update_model_deployment_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model_deployment_monitoring_job, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -4870,10 +4820,8 @@ def sample_update_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.UpdateModelDeploymentMonitoringJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, job_service.UpdateModelDeploymentMonitoringJobRequest ): @@ -5006,8 +4954,8 @@ def sample_delete_model_deployment_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -5015,10 +4963,8 @@ def sample_delete_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.DeleteModelDeploymentMonitoringJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, job_service.DeleteModelDeploymentMonitoringJobRequest ): @@ -5120,8 +5066,8 @@ def sample_pause_model_deployment_monitoring_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -5129,10 +5075,8 @@ def sample_pause_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.PauseModelDeploymentMonitoringJobRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, job_service.PauseModelDeploymentMonitoringJobRequest ): @@ -5222,8 +5166,8 @@ def sample_resume_model_deployment_monitoring_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -5231,10 +5175,8 @@ def sample_resume_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.ResumeModelDeploymentMonitoringJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, job_service.ResumeModelDeploymentMonitoringJobRequest ): diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py index 0869393579..d529402fc8 100644 --- a/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py @@ -74,7 +74,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -94,14 +94,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. 
If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -111,11 +114,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -142,7 +145,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -183,7 +186,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py index 71f45b994d..500a1a6c19 100644 --- a/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -89,7 +91,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -119,7 +120,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -139,15 +140,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -157,11 +161,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -188,7 +192,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -228,7 +232,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -1361,6 +1367,186 @@ def resume_model_deployment_monitoring_job( ) return self._stubs["resume_model_deployment_monitoring_job"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_custom_job: gapic_v1.method_async.wrap_method( + self.create_custom_job, + default_timeout=None, + client_info=client_info, + ), + self.get_custom_job: gapic_v1.method_async.wrap_method( + self.get_custom_job, + default_timeout=None, + client_info=client_info, + ), + self.list_custom_jobs: gapic_v1.method_async.wrap_method( + self.list_custom_jobs, + default_timeout=None, + client_info=client_info, + ), + self.delete_custom_job: gapic_v1.method_async.wrap_method( + self.delete_custom_job, + default_timeout=None, + client_info=client_info, + ), + self.cancel_custom_job: gapic_v1.method_async.wrap_method( + self.cancel_custom_job, + default_timeout=None, + client_info=client_info, + ), + self.create_data_labeling_job: gapic_v1.method_async.wrap_method( + self.create_data_labeling_job, + default_timeout=None, + client_info=client_info, + ), + self.get_data_labeling_job: gapic_v1.method_async.wrap_method( + self.get_data_labeling_job, + default_timeout=None, + client_info=client_info, + ), + self.list_data_labeling_jobs: gapic_v1.method_async.wrap_method( + self.list_data_labeling_jobs, + default_timeout=None, + client_info=client_info, + ), + self.delete_data_labeling_job: gapic_v1.method_async.wrap_method( + self.delete_data_labeling_job, + default_timeout=None, + client_info=client_info, + ), + 
self.cancel_data_labeling_job: gapic_v1.method_async.wrap_method( + self.cancel_data_labeling_job, + default_timeout=None, + client_info=client_info, + ), + self.create_hyperparameter_tuning_job: gapic_v1.method_async.wrap_method( + self.create_hyperparameter_tuning_job, + default_timeout=None, + client_info=client_info, + ), + self.get_hyperparameter_tuning_job: gapic_v1.method_async.wrap_method( + self.get_hyperparameter_tuning_job, + default_timeout=None, + client_info=client_info, + ), + self.list_hyperparameter_tuning_jobs: gapic_v1.method_async.wrap_method( + self.list_hyperparameter_tuning_jobs, + default_timeout=None, + client_info=client_info, + ), + self.delete_hyperparameter_tuning_job: gapic_v1.method_async.wrap_method( + self.delete_hyperparameter_tuning_job, + default_timeout=None, + client_info=client_info, + ), + self.cancel_hyperparameter_tuning_job: gapic_v1.method_async.wrap_method( + self.cancel_hyperparameter_tuning_job, + default_timeout=None, + client_info=client_info, + ), + self.create_nas_job: gapic_v1.method_async.wrap_method( + self.create_nas_job, + default_timeout=None, + client_info=client_info, + ), + self.get_nas_job: gapic_v1.method_async.wrap_method( + self.get_nas_job, + default_timeout=None, + client_info=client_info, + ), + self.list_nas_jobs: gapic_v1.method_async.wrap_method( + self.list_nas_jobs, + default_timeout=None, + client_info=client_info, + ), + self.delete_nas_job: gapic_v1.method_async.wrap_method( + self.delete_nas_job, + default_timeout=None, + client_info=client_info, + ), + self.cancel_nas_job: gapic_v1.method_async.wrap_method( + self.cancel_nas_job, + default_timeout=None, + client_info=client_info, + ), + self.get_nas_trial_detail: gapic_v1.method_async.wrap_method( + self.get_nas_trial_detail, + default_timeout=None, + client_info=client_info, + ), + self.list_nas_trial_details: gapic_v1.method_async.wrap_method( + self.list_nas_trial_details, + default_timeout=None, + client_info=client_info, + ), + 
self.create_batch_prediction_job: gapic_v1.method_async.wrap_method( + self.create_batch_prediction_job, + default_timeout=None, + client_info=client_info, + ), + self.get_batch_prediction_job: gapic_v1.method_async.wrap_method( + self.get_batch_prediction_job, + default_timeout=None, + client_info=client_info, + ), + self.list_batch_prediction_jobs: gapic_v1.method_async.wrap_method( + self.list_batch_prediction_jobs, + default_timeout=None, + client_info=client_info, + ), + self.delete_batch_prediction_job: gapic_v1.method_async.wrap_method( + self.delete_batch_prediction_job, + default_timeout=None, + client_info=client_info, + ), + self.cancel_batch_prediction_job: gapic_v1.method_async.wrap_method( + self.cancel_batch_prediction_job, + default_timeout=None, + client_info=client_info, + ), + self.create_model_deployment_monitoring_job: gapic_v1.method_async.wrap_method( + self.create_model_deployment_monitoring_job, + default_timeout=None, + client_info=client_info, + ), + self.search_model_deployment_monitoring_stats_anomalies: gapic_v1.method_async.wrap_method( + self.search_model_deployment_monitoring_stats_anomalies, + default_timeout=None, + client_info=client_info, + ), + self.get_model_deployment_monitoring_job: gapic_v1.method_async.wrap_method( + self.get_model_deployment_monitoring_job, + default_timeout=None, + client_info=client_info, + ), + self.list_model_deployment_monitoring_jobs: gapic_v1.method_async.wrap_method( + self.list_model_deployment_monitoring_jobs, + default_timeout=None, + client_info=client_info, + ), + self.update_model_deployment_monitoring_job: gapic_v1.method_async.wrap_method( + self.update_model_deployment_monitoring_job, + default_timeout=None, + client_info=client_info, + ), + self.delete_model_deployment_monitoring_job: gapic_v1.method_async.wrap_method( + self.delete_model_deployment_monitoring_job, + default_timeout=None, + client_info=client_info, + ), + self.pause_model_deployment_monitoring_job: 
gapic_v1.method_async.wrap_method( + self.pause_model_deployment_monitoring_job, + default_timeout=None, + client_info=client_info, + ), + self.resume_model_deployment_monitoring_job: gapic_v1.method_async.wrap_method( + self.resume_model_deployment_monitoring_job, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/llm_utility_service/async_client.py b/google/cloud/aiplatform_v1/services/llm_utility_service/async_client.py index 19782868c3..7b8c862a6d 100644 --- a/google/cloud/aiplatform_v1/services/llm_utility_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/llm_utility_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -198,7 +199,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, LlmUtilityServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + LlmUtilityServiceTransport, + Callable[..., LlmUtilityServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -210,9 +217,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.LlmUtilityServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,LlmUtilityServiceTransport,Callable[..., LlmUtilityServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the LlmUtilityServiceTransport constructor. + If set to None, a transport is chosen automatically. 
NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -337,8 +346,8 @@ async def sample_count_tokens(): Response message for [PredictionService.CountTokens][]. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances]) if request is not None and has_flattened_params: raise ValueError( @@ -346,7 +355,10 @@ async def sample_count_tokens(): "the individual field arguments should be set." ) - request = prediction_service.CountTokensRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.CountTokensRequest): + request = prediction_service.CountTokensRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -357,11 +369,9 @@ async def sample_count_tokens(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.count_tokens, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.count_tokens + ] # Certain fields should be provided within the metadata header; # add these here. @@ -461,8 +471,8 @@ async def sample_compute_tokens(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances]) if request is not None and has_flattened_params: raise ValueError( @@ -470,7 +480,10 @@ async def sample_compute_tokens(): "the individual field arguments should be set." ) - request = llm_utility_service.ComputeTokensRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, llm_utility_service.ComputeTokensRequest): + request = llm_utility_service.ComputeTokensRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -481,11 +494,9 @@ async def sample_compute_tokens(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.compute_tokens, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.compute_tokens + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1/services/llm_utility_service/client.py b/google/cloud/aiplatform_v1/services/llm_utility_service/client.py index eb5da26559..17ccf7fdb5 100644 --- a/google/cloud/aiplatform_v1/services/llm_utility_service/client.py +++ b/google/cloud/aiplatform_v1/services/llm_utility_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -533,7 +534,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, LlmUtilityServiceTransport]] = None, + transport: Optional[ + Union[ + str, + LlmUtilityServiceTransport, + Callable[..., LlmUtilityServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -545,9 +552,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, LlmUtilityServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,LlmUtilityServiceTransport,Callable[..., LlmUtilityServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the LlmUtilityServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -659,8 +668,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[LlmUtilityServiceTransport], + Callable[..., LlmUtilityServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., LlmUtilityServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -750,8 +767,8 @@ def sample_count_tokens(): Response message for [PredictionService.CountTokens][]. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances]) if request is not None and has_flattened_params: raise ValueError( @@ -759,10 +776,8 @@ def sample_count_tokens(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.CountTokensRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, prediction_service.CountTokensRequest): request = prediction_service.CountTokensRequest(request) # If we have keyword arguments corresponding to fields on the @@ -874,8 +889,8 @@ def sample_compute_tokens(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances]) if request is not None and has_flattened_params: raise ValueError( @@ -883,10 +898,8 @@ def sample_compute_tokens(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a llm_utility_service.ComputeTokensRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, llm_utility_service.ComputeTokensRequest): request = llm_utility_service.ComputeTokensRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1/services/llm_utility_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/llm_utility_service/transports/grpc.py index 6405833b44..c8554db89e 100644 --- a/google/cloud/aiplatform_v1/services/llm_utility_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/llm_utility_service/transports/grpc.py @@ -55,7 +55,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -75,14 +75,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the 
credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -92,11 +95,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. 
client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -122,7 +125,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. @@ -163,7 +166,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/llm_utility_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/llm_utility_service/transports/grpc_asyncio.py index 614a8d53e1..9bc3529b46 100644 --- a/google/cloud/aiplatform_v1/services/llm_utility_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/llm_utility_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -70,7 +72,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -100,7 +101,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -120,15 +121,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -138,11 +142,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -168,7 +172,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -208,7 +212,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -295,6 +301,21 @@ def compute_tokens( ) return self._stubs["compute_tokens"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.count_tokens: gapic_v1.method_async.wrap_method( + self.count_tokens, + default_timeout=None, + client_info=client_info, + ), + self.compute_tokens: gapic_v1.method_async.wrap_method( + self.compute_tokens, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/match_service/async_client.py b/google/cloud/aiplatform_v1/services/match_service/async_client.py index 34496f9a8c..c430407600 100644 --- a/google/cloud/aiplatform_v1/services/match_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/match_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -197,7 +198,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, MatchServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[str, MatchServiceTransport, Callable[..., MatchServiceTransport]] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -209,9 +212,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. 
- transport (Union[str, ~.MatchServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,MatchServiceTransport,Callable[..., MatchServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the MatchServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -312,15 +317,16 @@ async def sample_find_neighbors(): """ # Create or coerce a protobuf request object. - request = match_service.FindNeighborsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, match_service.FindNeighborsRequest): + request = match_service.FindNeighborsRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.find_neighbors, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.find_neighbors + ] # Certain fields should be provided within the metadata header; # add these here. @@ -399,15 +405,16 @@ async def sample_read_index_datapoints(): """ # Create or coerce a protobuf request object. - request = match_service.ReadIndexDatapointsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, match_service.ReadIndexDatapointsRequest): + request = match_service.ReadIndexDatapointsRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.read_index_datapoints, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.read_index_datapoints + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1/services/match_service/client.py b/google/cloud/aiplatform_v1/services/match_service/client.py index 4fe0835624..9c94004d68 100644 --- a/google/cloud/aiplatform_v1/services/match_service/client.py +++ b/google/cloud/aiplatform_v1/services/match_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -532,7 +533,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, MatchServiceTransport]] = None, + transport: Optional[ + Union[str, MatchServiceTransport, Callable[..., MatchServiceTransport]] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -544,9 +547,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, MatchServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,MatchServiceTransport,Callable[..., MatchServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. 
+ If a Callable is given, it will be called with the same set of initialization + arguments as used in the MatchServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -655,8 +660,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[MatchServiceTransport], Callable[..., MatchServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., MatchServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -722,10 +734,8 @@ def sample_find_neighbors(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a match_service.FindNeighborsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, match_service.FindNeighborsRequest): request = match_service.FindNeighborsRequest(request) @@ -810,10 +820,8 @@ def sample_read_index_datapoints(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a match_service.ReadIndexDatapointsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, match_service.ReadIndexDatapointsRequest): request = match_service.ReadIndexDatapointsRequest(request) diff --git a/google/cloud/aiplatform_v1/services/match_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/match_service/transports/grpc.py index cf3f50bd96..8d15cadf46 100644 --- a/google/cloud/aiplatform_v1/services/match_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/match_service/transports/grpc.py @@ -55,7 +55,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -75,14 +75,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. 
If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -92,11 +95,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -122,7 +125,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -163,7 +166,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/match_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/match_service/transports/grpc_asyncio.py index f9fefc7685..b90b2f131e 100644 --- a/google/cloud/aiplatform_v1/services/match_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/match_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -70,7 +72,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -100,7 +101,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -120,15 +121,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -138,11 +142,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -168,7 +172,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -208,7 +212,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -298,6 +304,21 @@ def read_index_datapoints( ) return self._stubs["read_index_datapoints"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.find_neighbors: gapic_v1.method_async.wrap_method( + self.find_neighbors, + default_timeout=None, + client_info=client_info, + ), + self.read_index_datapoints: gapic_v1.method_async.wrap_method( + self.read_index_datapoints, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/metadata_service/async_client.py b/google/cloud/aiplatform_v1/services/metadata_service/async_client.py index 1afbd8c080..a777b013c7 100644 --- a/google/cloud/aiplatform_v1/services/metadata_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/metadata_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -229,7 +230,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, MetadataServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, MetadataServiceTransport, Callable[..., MetadataServiceTransport] + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -241,9 +246,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the 
credentials from the environment. - transport (Union[str, ~.MetadataServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,MetadataServiceTransport,Callable[..., MetadataServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the MetadataServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -385,8 +392,8 @@ async def sample_create_metadata_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, metadata_store, metadata_store_id]) if request is not None and has_flattened_params: raise ValueError( @@ -394,7 +401,10 @@ async def sample_create_metadata_store(): "the individual field arguments should be set." ) - request = metadata_service.CreateMetadataStoreRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.CreateMetadataStoreRequest): + request = metadata_service.CreateMetadataStoreRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -407,11 +417,9 @@ async def sample_create_metadata_store(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_metadata_store, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_metadata_store + ] # Certain fields should be provided within the metadata header; # add these here. @@ -504,8 +512,8 @@ async def sample_get_metadata_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -513,7 +521,10 @@ async def sample_get_metadata_store(): "the individual field arguments should be set." ) - request = metadata_service.GetMetadataStoreRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.GetMetadataStoreRequest): + request = metadata_service.GetMetadataStoreRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -522,11 +533,9 @@ async def sample_get_metadata_store(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_metadata_store, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_metadata_store + ] # Certain fields should be provided within the metadata header; # add these here. @@ -616,8 +625,8 @@ async def sample_list_metadata_stores(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -625,7 +634,10 @@ async def sample_list_metadata_stores(): "the individual field arguments should be set." ) - request = metadata_service.ListMetadataStoresRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.ListMetadataStoresRequest): + request = metadata_service.ListMetadataStoresRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -634,11 +646,9 @@ async def sample_list_metadata_stores(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_metadata_stores, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_metadata_stores + ] # Certain fields should be provided within the metadata header; # add these here. @@ -748,8 +758,8 @@ async def sample_delete_metadata_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -757,7 +767,10 @@ async def sample_delete_metadata_store(): "the individual field arguments should be set." 
) - request = metadata_service.DeleteMetadataStoreRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.DeleteMetadataStoreRequest): + request = metadata_service.DeleteMetadataStoreRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -766,11 +779,9 @@ async def sample_delete_metadata_store(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_metadata_store, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_metadata_store + ] # Certain fields should be provided within the metadata header; # add these here. @@ -882,8 +893,8 @@ async def sample_create_artifact(): Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, artifact, artifact_id]) if request is not None and has_flattened_params: raise ValueError( @@ -891,7 +902,10 @@ async def sample_create_artifact(): "the individual field arguments should be set." ) - request = metadata_service.CreateArtifactRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.CreateArtifactRequest): + request = metadata_service.CreateArtifactRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -904,11 +918,9 @@ async def sample_create_artifact(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_artifact, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_artifact + ] # Certain fields should be provided within the metadata header; # add these here. @@ -990,8 +1002,8 @@ async def sample_get_artifact(): Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -999,7 +1011,10 @@ async def sample_get_artifact(): "the individual field arguments should be set." ) - request = metadata_service.GetArtifactRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.GetArtifactRequest): + request = metadata_service.GetArtifactRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1008,11 +1023,9 @@ async def sample_get_artifact(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_artifact, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_artifact + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -1100,8 +1113,8 @@ async def sample_list_artifacts(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1109,7 +1122,10 @@ async def sample_list_artifacts(): "the individual field arguments should be set." ) - request = metadata_service.ListArtifactsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.ListArtifactsRequest): + request = metadata_service.ListArtifactsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1118,11 +1134,9 @@ async def sample_list_artifacts(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_artifacts, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_artifacts + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1223,8 +1237,8 @@ async def sample_update_artifact(): Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([artifact, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1232,7 +1246,10 @@ async def sample_update_artifact(): "the individual field arguments should be set." ) - request = metadata_service.UpdateArtifactRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.UpdateArtifactRequest): + request = metadata_service.UpdateArtifactRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1243,11 +1260,9 @@ async def sample_update_artifact(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_artifact, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_artifact + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1347,8 +1362,8 @@ async def sample_delete_artifact(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1356,7 +1371,10 @@ async def sample_delete_artifact(): "the individual field arguments should be set." ) - request = metadata_service.DeleteArtifactRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, metadata_service.DeleteArtifactRequest): + request = metadata_service.DeleteArtifactRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1365,11 +1383,9 @@ async def sample_delete_artifact(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_artifact, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_artifact + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1468,8 +1484,8 @@ async def sample_purge_artifacts(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1477,7 +1493,10 @@ async def sample_purge_artifacts(): "the individual field arguments should be set." ) - request = metadata_service.PurgeArtifactsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.PurgeArtifactsRequest): + request = metadata_service.PurgeArtifactsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1486,11 +1505,9 @@ async def sample_purge_artifacts(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.purge_artifacts, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.purge_artifacts + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1602,8 +1619,8 @@ async def sample_create_context(): Instance of a general context. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, context, context_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1611,7 +1628,10 @@ async def sample_create_context(): "the individual field arguments should be set." ) - request = metadata_service.CreateContextRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.CreateContextRequest): + request = metadata_service.CreateContextRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1624,11 +1644,9 @@ async def sample_create_context(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_context, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_context + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1710,8 +1728,8 @@ async def sample_get_context(): Instance of a general context. """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1719,7 +1737,10 @@ async def sample_get_context(): "the individual field arguments should be set." ) - request = metadata_service.GetContextRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.GetContextRequest): + request = metadata_service.GetContextRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1728,11 +1749,9 @@ async def sample_get_context(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_context, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_context + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1820,8 +1839,8 @@ async def sample_list_contexts(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1829,7 +1848,10 @@ async def sample_list_contexts(): "the individual field arguments should be set." 
) - request = metadata_service.ListContextsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.ListContextsRequest): + request = metadata_service.ListContextsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1838,11 +1860,9 @@ async def sample_list_contexts(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_contexts, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_contexts + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1942,8 +1962,8 @@ async def sample_update_context(): Instance of a general context. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([context, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1951,7 +1971,10 @@ async def sample_update_context(): "the individual field arguments should be set." ) - request = metadata_service.UpdateContextRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.UpdateContextRequest): + request = metadata_service.UpdateContextRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -1962,11 +1985,9 @@ async def sample_update_context(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_context, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_context + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2066,8 +2087,8 @@ async def sample_delete_context(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2075,7 +2096,10 @@ async def sample_delete_context(): "the individual field arguments should be set." ) - request = metadata_service.DeleteContextRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.DeleteContextRequest): + request = metadata_service.DeleteContextRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2084,11 +2108,9 @@ async def sample_delete_context(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_context, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_context + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -2187,8 +2209,8 @@ async def sample_purge_contexts(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2196,7 +2218,10 @@ async def sample_purge_contexts(): "the individual field arguments should be set." ) - request = metadata_service.PurgeContextsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.PurgeContextsRequest): + request = metadata_service.PurgeContextsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2205,11 +2230,9 @@ async def sample_purge_contexts(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.purge_contexts, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.purge_contexts + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2327,8 +2350,8 @@ async def sample_add_context_artifacts_and_executions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([context, artifacts, executions]) if request is not None and has_flattened_params: raise ValueError( @@ -2336,7 +2359,12 @@ async def sample_add_context_artifacts_and_executions(): "the individual field arguments should be set." ) - request = metadata_service.AddContextArtifactsAndExecutionsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, metadata_service.AddContextArtifactsAndExecutionsRequest + ): + request = metadata_service.AddContextArtifactsAndExecutionsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2349,11 +2377,9 @@ async def sample_add_context_artifacts_and_executions(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.add_context_artifacts_and_executions, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.add_context_artifacts_and_executions + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2452,8 +2478,8 @@ async def sample_add_context_children(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([context, child_contexts]) if request is not None and has_flattened_params: raise ValueError( @@ -2461,7 +2487,10 @@ async def sample_add_context_children(): "the individual field arguments should be set." 
) - request = metadata_service.AddContextChildrenRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.AddContextChildrenRequest): + request = metadata_service.AddContextChildrenRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2472,11 +2501,9 @@ async def sample_add_context_children(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.add_context_children, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.add_context_children + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2573,8 +2600,8 @@ async def sample_remove_context_children(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([context, child_contexts]) if request is not None and has_flattened_params: raise ValueError( @@ -2582,7 +2609,10 @@ async def sample_remove_context_children(): "the individual field arguments should be set." ) - request = metadata_service.RemoveContextChildrenRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.RemoveContextChildrenRequest): + request = metadata_service.RemoveContextChildrenRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -2593,11 +2623,9 @@ async def sample_remove_context_children(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.remove_context_children, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.remove_context_children + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2692,8 +2720,8 @@ async def sample_query_context_lineage_subgraph(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([context]) if request is not None and has_flattened_params: raise ValueError( @@ -2701,7 +2729,10 @@ async def sample_query_context_lineage_subgraph(): "the individual field arguments should be set." ) - request = metadata_service.QueryContextLineageSubgraphRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.QueryContextLineageSubgraphRequest): + request = metadata_service.QueryContextLineageSubgraphRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2710,11 +2741,9 @@ async def sample_query_context_lineage_subgraph(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.query_context_lineage_subgraph, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.query_context_lineage_subgraph + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2818,8 +2847,8 @@ async def sample_create_execution(): Instance of a general execution. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, execution, execution_id]) if request is not None and has_flattened_params: raise ValueError( @@ -2827,7 +2856,10 @@ async def sample_create_execution(): "the individual field arguments should be set." ) - request = metadata_service.CreateExecutionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.CreateExecutionRequest): + request = metadata_service.CreateExecutionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2840,11 +2872,9 @@ async def sample_create_execution(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_execution, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_execution + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2926,8 +2956,8 @@ async def sample_get_execution(): Instance of a general execution. 
""" # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2935,7 +2965,10 @@ async def sample_get_execution(): "the individual field arguments should be set." ) - request = metadata_service.GetExecutionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.GetExecutionRequest): + request = metadata_service.GetExecutionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2944,11 +2977,9 @@ async def sample_get_execution(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_execution, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_execution + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3036,8 +3067,8 @@ async def sample_list_executions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3045,7 +3076,10 @@ async def sample_list_executions(): "the individual field arguments should be set." 
) - request = metadata_service.ListExecutionsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.ListExecutionsRequest): + request = metadata_service.ListExecutionsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3054,11 +3088,9 @@ async def sample_list_executions(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_executions, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_executions + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3159,8 +3191,8 @@ async def sample_update_execution(): Instance of a general execution. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([execution, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -3168,7 +3200,10 @@ async def sample_update_execution(): "the individual field arguments should be set." ) - request = metadata_service.UpdateExecutionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.UpdateExecutionRequest): + request = metadata_service.UpdateExecutionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -3179,11 +3214,9 @@ async def sample_update_execution(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_execution, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_execution + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3283,8 +3316,8 @@ async def sample_delete_execution(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3292,7 +3325,10 @@ async def sample_delete_execution(): "the individual field arguments should be set." ) - request = metadata_service.DeleteExecutionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.DeleteExecutionRequest): + request = metadata_service.DeleteExecutionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3301,11 +3337,9 @@ async def sample_delete_execution(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_execution, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_execution + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -3404,8 +3438,8 @@ async def sample_purge_executions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3413,7 +3447,10 @@ async def sample_purge_executions(): "the individual field arguments should be set." ) - request = metadata_service.PurgeExecutionsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.PurgeExecutionsRequest): + request = metadata_service.PurgeExecutionsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3422,11 +3459,9 @@ async def sample_purge_executions(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.purge_executions, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.purge_executions + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3530,8 +3565,8 @@ async def sample_add_execution_events(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([execution, events]) if request is not None and has_flattened_params: raise ValueError( @@ -3539,7 +3574,10 @@ async def sample_add_execution_events(): "the individual field arguments should be set." ) - request = metadata_service.AddExecutionEventsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.AddExecutionEventsRequest): + request = metadata_service.AddExecutionEventsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3550,11 +3588,9 @@ async def sample_add_execution_events(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.add_execution_events, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.add_execution_events + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3646,8 +3682,8 @@ async def sample_query_execution_inputs_and_outputs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([execution]) if request is not None and has_flattened_params: raise ValueError( @@ -3655,7 +3691,12 @@ async def sample_query_execution_inputs_and_outputs(): "the individual field arguments should be set." ) - request = metadata_service.QueryExecutionInputsAndOutputsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance( + request, metadata_service.QueryExecutionInputsAndOutputsRequest + ): + request = metadata_service.QueryExecutionInputsAndOutputsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3664,11 +3705,9 @@ async def sample_query_execution_inputs_and_outputs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.query_execution_inputs_and_outputs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.query_execution_inputs_and_outputs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3782,8 +3821,8 @@ async def sample_create_metadata_schema(): Instance of a general MetadataSchema. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, metadata_schema, metadata_schema_id]) if request is not None and has_flattened_params: raise ValueError( @@ -3791,7 +3830,10 @@ async def sample_create_metadata_schema(): "the individual field arguments should be set." ) - request = metadata_service.CreateMetadataSchemaRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.CreateMetadataSchemaRequest): + request = metadata_service.CreateMetadataSchemaRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -3804,11 +3846,9 @@ async def sample_create_metadata_schema(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_metadata_schema, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_metadata_schema + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3892,8 +3932,8 @@ async def sample_get_metadata_schema(): Instance of a general MetadataSchema. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3901,7 +3941,10 @@ async def sample_get_metadata_schema(): "the individual field arguments should be set." ) - request = metadata_service.GetMetadataSchemaRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.GetMetadataSchemaRequest): + request = metadata_service.GetMetadataSchemaRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3910,11 +3953,9 @@ async def sample_get_metadata_schema(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_metadata_schema, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_metadata_schema + ] # Certain fields should be provided within the metadata header; # add these here. @@ -4004,8 +4045,8 @@ async def sample_list_metadata_schemas(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -4013,7 +4054,10 @@ async def sample_list_metadata_schemas(): "the individual field arguments should be set." ) - request = metadata_service.ListMetadataSchemasRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.ListMetadataSchemasRequest): + request = metadata_service.ListMetadataSchemasRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -4022,11 +4066,9 @@ async def sample_list_metadata_schemas(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_metadata_schemas, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_metadata_schemas + ] # Certain fields should be provided within the metadata header; # add these here. @@ -4130,8 +4172,8 @@ async def sample_query_artifact_lineage_subgraph(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([artifact]) if request is not None and has_flattened_params: raise ValueError( @@ -4139,7 +4181,12 @@ async def sample_query_artifact_lineage_subgraph(): "the individual field arguments should be set." ) - request = metadata_service.QueryArtifactLineageSubgraphRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, metadata_service.QueryArtifactLineageSubgraphRequest + ): + request = metadata_service.QueryArtifactLineageSubgraphRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -4148,11 +4195,9 @@ async def sample_query_artifact_lineage_subgraph(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.query_artifact_lineage_subgraph, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.query_artifact_lineage_subgraph + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1/services/metadata_service/client.py b/google/cloud/aiplatform_v1/services/metadata_service/client.py index fde73daf99..7bce6ccfd7 100644 --- a/google/cloud/aiplatform_v1/services/metadata_service/client.py +++ b/google/cloud/aiplatform_v1/services/metadata_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -648,7 +649,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, MetadataServiceTransport]] = None, + transport: Optional[ + Union[ + str, MetadataServiceTransport, Callable[..., MetadataServiceTransport] + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -660,9 +665,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, MetadataServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,MetadataServiceTransport,Callable[..., MetadataServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the MetadataServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -774,8 +781,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[MetadataServiceTransport], Callable[..., MetadataServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., MetadataServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -882,8 +896,8 @@ def sample_create_metadata_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, metadata_store, metadata_store_id]) if request is not None and has_flattened_params: raise ValueError( @@ -891,10 +905,8 @@ def sample_create_metadata_store(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.CreateMetadataStoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.CreateMetadataStoreRequest): request = metadata_service.CreateMetadataStoreRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1001,8 +1013,8 @@ def sample_get_metadata_store(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1010,10 +1022,8 @@ def sample_get_metadata_store(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.GetMetadataStoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.GetMetadataStoreRequest): request = metadata_service.GetMetadataStoreRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1113,8 +1123,8 @@ def sample_list_metadata_stores(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1122,10 +1132,8 @@ def sample_list_metadata_stores(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.ListMetadataStoresRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, metadata_service.ListMetadataStoresRequest): request = metadata_service.ListMetadataStoresRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1245,8 +1253,8 @@ def sample_delete_metadata_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1254,10 +1262,8 @@ def sample_delete_metadata_store(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.DeleteMetadataStoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.DeleteMetadataStoreRequest): request = metadata_service.DeleteMetadataStoreRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1379,8 +1385,8 @@ def sample_create_artifact(): Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, artifact, artifact_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1388,10 +1394,8 @@ def sample_create_artifact(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.CreateArtifactRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.CreateArtifactRequest): request = metadata_service.CreateArtifactRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1487,8 +1491,8 @@ def sample_get_artifact(): Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1496,10 +1500,8 @@ def sample_get_artifact(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.GetArtifactRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.GetArtifactRequest): request = metadata_service.GetArtifactRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1597,8 +1599,8 @@ def sample_list_artifacts(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1606,10 +1608,8 @@ def sample_list_artifacts(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.ListArtifactsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.ListArtifactsRequest): request = metadata_service.ListArtifactsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1720,8 +1720,8 @@ def sample_update_artifact(): Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([artifact, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1729,10 +1729,8 @@ def sample_update_artifact(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.UpdateArtifactRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, metadata_service.UpdateArtifactRequest): request = metadata_service.UpdateArtifactRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1844,8 +1842,8 @@ def sample_delete_artifact(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1853,10 +1851,8 @@ def sample_delete_artifact(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.DeleteArtifactRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.DeleteArtifactRequest): request = metadata_service.DeleteArtifactRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1965,8 +1961,8 @@ def sample_purge_artifacts(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1974,10 +1970,8 @@ def sample_purge_artifacts(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.PurgeArtifactsRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.PurgeArtifactsRequest): request = metadata_service.PurgeArtifactsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2099,8 +2093,8 @@ def sample_create_context(): Instance of a general context. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, context, context_id]) if request is not None and has_flattened_params: raise ValueError( @@ -2108,10 +2102,8 @@ def sample_create_context(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.CreateContextRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.CreateContextRequest): request = metadata_service.CreateContextRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2207,8 +2199,8 @@ def sample_get_context(): Instance of a general context. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2216,10 +2208,8 @@ def sample_get_context(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.GetContextRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.GetContextRequest): request = metadata_service.GetContextRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2317,8 +2307,8 @@ def sample_list_contexts(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2326,10 +2316,8 @@ def sample_list_contexts(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.ListContextsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.ListContextsRequest): request = metadata_service.ListContextsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2439,8 +2427,8 @@ def sample_update_context(): Instance of a general context. """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([context, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -2448,10 +2436,8 @@ def sample_update_context(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.UpdateContextRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.UpdateContextRequest): request = metadata_service.UpdateContextRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2563,8 +2549,8 @@ def sample_delete_context(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2572,10 +2558,8 @@ def sample_delete_context(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.DeleteContextRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, metadata_service.DeleteContextRequest): request = metadata_service.DeleteContextRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2684,8 +2668,8 @@ def sample_purge_contexts(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2693,10 +2677,8 @@ def sample_purge_contexts(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.PurgeContextsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.PurgeContextsRequest): request = metadata_service.PurgeContextsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2824,8 +2806,8 @@ def sample_add_context_artifacts_and_executions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([context, artifacts, executions]) if request is not None and has_flattened_params: raise ValueError( @@ -2833,10 +2815,8 @@ def sample_add_context_artifacts_and_executions(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.AddContextArtifactsAndExecutionsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, metadata_service.AddContextArtifactsAndExecutionsRequest ): @@ -2953,8 +2933,8 @@ def sample_add_context_children(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([context, child_contexts]) if request is not None and has_flattened_params: raise ValueError( @@ -2962,10 +2942,8 @@ def sample_add_context_children(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.AddContextChildrenRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.AddContextChildrenRequest): request = metadata_service.AddContextChildrenRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3074,8 +3052,8 @@ def sample_remove_context_children(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([context, child_contexts]) if request is not None and has_flattened_params: raise ValueError( @@ -3083,10 +3061,8 @@ def sample_remove_context_children(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.RemoveContextChildrenRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.RemoveContextChildrenRequest): request = metadata_service.RemoveContextChildrenRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3193,8 +3169,8 @@ def sample_query_context_lineage_subgraph(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([context]) if request is not None and has_flattened_params: raise ValueError( @@ -3202,10 +3178,8 @@ def sample_query_context_lineage_subgraph(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.QueryContextLineageSubgraphRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, metadata_service.QueryContextLineageSubgraphRequest): request = metadata_service.QueryContextLineageSubgraphRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3321,8 +3295,8 @@ def sample_create_execution(): Instance of a general execution. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, execution, execution_id]) if request is not None and has_flattened_params: raise ValueError( @@ -3330,10 +3304,8 @@ def sample_create_execution(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.CreateExecutionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.CreateExecutionRequest): request = metadata_service.CreateExecutionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3429,8 +3401,8 @@ def sample_get_execution(): Instance of a general execution. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3438,10 +3410,8 @@ def sample_get_execution(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.GetExecutionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.GetExecutionRequest): request = metadata_service.GetExecutionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3539,8 +3509,8 @@ def sample_list_executions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3548,10 +3518,8 @@ def sample_list_executions(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.ListExecutionsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.ListExecutionsRequest): request = metadata_service.ListExecutionsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3662,8 +3630,8 @@ def sample_update_execution(): Instance of a general execution. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([execution, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -3671,10 +3639,8 @@ def sample_update_execution(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.UpdateExecutionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.UpdateExecutionRequest): request = metadata_service.UpdateExecutionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3786,8 +3752,8 @@ def sample_delete_execution(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3795,10 +3761,8 @@ def sample_delete_execution(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.DeleteExecutionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, metadata_service.DeleteExecutionRequest): request = metadata_service.DeleteExecutionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3907,8 +3871,8 @@ def sample_purge_executions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3916,10 +3880,8 @@ def sample_purge_executions(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.PurgeExecutionsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.PurgeExecutionsRequest): request = metadata_service.PurgeExecutionsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4033,8 +3995,8 @@ def sample_add_execution_events(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([execution, events]) if request is not None and has_flattened_params: raise ValueError( @@ -4042,10 +4004,8 @@ def sample_add_execution_events(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.AddExecutionEventsRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.AddExecutionEventsRequest): request = metadata_service.AddExecutionEventsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4149,8 +4109,8 @@ def sample_query_execution_inputs_and_outputs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([execution]) if request is not None and has_flattened_params: raise ValueError( @@ -4158,10 +4118,8 @@ def sample_query_execution_inputs_and_outputs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.QueryExecutionInputsAndOutputsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, metadata_service.QueryExecutionInputsAndOutputsRequest ): @@ -4289,8 +4247,8 @@ def sample_create_metadata_schema(): Instance of a general MetadataSchema. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, metadata_schema, metadata_schema_id]) if request is not None and has_flattened_params: raise ValueError( @@ -4298,10 +4256,8 @@ def sample_create_metadata_schema(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.CreateMetadataSchemaRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.CreateMetadataSchemaRequest): request = metadata_service.CreateMetadataSchemaRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4399,8 +4355,8 @@ def sample_get_metadata_schema(): Instance of a general MetadataSchema. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -4408,10 +4364,8 @@ def sample_get_metadata_schema(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.GetMetadataSchemaRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, metadata_service.GetMetadataSchemaRequest): request = metadata_service.GetMetadataSchemaRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4511,8 +4465,8 @@ def sample_list_metadata_schemas(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -4520,10 +4474,8 @@ def sample_list_metadata_schemas(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.ListMetadataSchemasRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.ListMetadataSchemasRequest): request = metadata_service.ListMetadataSchemasRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4637,8 +4589,8 @@ def sample_query_artifact_lineage_subgraph(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([artifact]) if request is not None and has_flattened_params: raise ValueError( @@ -4646,10 +4598,8 @@ def sample_query_artifact_lineage_subgraph(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.QueryArtifactLineageSubgraphRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, metadata_service.QueryArtifactLineageSubgraphRequest ): diff --git a/google/cloud/aiplatform_v1/services/metadata_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/metadata_service/transports/grpc.py index 6625516f34..4a0601e2f4 100644 --- a/google/cloud/aiplatform_v1/services/metadata_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/metadata_service/transports/grpc.py @@ -65,7 +65,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -85,14 +85,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. 
- channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -102,11 +105,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -133,7 +136,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -174,7 +177,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/metadata_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/metadata_service/transports/grpc_asyncio.py index 205b987364..58761e4cdc 100644 --- a/google/cloud/aiplatform_v1/services/metadata_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/metadata_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -80,7 +82,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -110,7 +111,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -130,15 +131,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -148,11 +152,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -179,7 +183,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -219,7 +223,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -1204,6 +1210,171 @@ def query_artifact_lineage_subgraph( ) return self._stubs["query_artifact_lineage_subgraph"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_metadata_store: gapic_v1.method_async.wrap_method( + self.create_metadata_store, + default_timeout=None, + client_info=client_info, + ), + self.get_metadata_store: gapic_v1.method_async.wrap_method( + self.get_metadata_store, + default_timeout=None, + client_info=client_info, + ), + self.list_metadata_stores: gapic_v1.method_async.wrap_method( + self.list_metadata_stores, + default_timeout=None, + client_info=client_info, + ), + self.delete_metadata_store: gapic_v1.method_async.wrap_method( + self.delete_metadata_store, + default_timeout=None, + client_info=client_info, + ), + self.create_artifact: gapic_v1.method_async.wrap_method( + self.create_artifact, + default_timeout=None, + client_info=client_info, + ), + self.get_artifact: gapic_v1.method_async.wrap_method( + self.get_artifact, + default_timeout=None, + client_info=client_info, + ), + self.list_artifacts: gapic_v1.method_async.wrap_method( + self.list_artifacts, + default_timeout=None, + client_info=client_info, + ), + self.update_artifact: gapic_v1.method_async.wrap_method( + self.update_artifact, + default_timeout=None, + client_info=client_info, + ), + self.delete_artifact: gapic_v1.method_async.wrap_method( + self.delete_artifact, + default_timeout=None, + client_info=client_info, + ), + self.purge_artifacts: gapic_v1.method_async.wrap_method( + 
self.purge_artifacts, + default_timeout=None, + client_info=client_info, + ), + self.create_context: gapic_v1.method_async.wrap_method( + self.create_context, + default_timeout=None, + client_info=client_info, + ), + self.get_context: gapic_v1.method_async.wrap_method( + self.get_context, + default_timeout=None, + client_info=client_info, + ), + self.list_contexts: gapic_v1.method_async.wrap_method( + self.list_contexts, + default_timeout=None, + client_info=client_info, + ), + self.update_context: gapic_v1.method_async.wrap_method( + self.update_context, + default_timeout=None, + client_info=client_info, + ), + self.delete_context: gapic_v1.method_async.wrap_method( + self.delete_context, + default_timeout=None, + client_info=client_info, + ), + self.purge_contexts: gapic_v1.method_async.wrap_method( + self.purge_contexts, + default_timeout=None, + client_info=client_info, + ), + self.add_context_artifacts_and_executions: gapic_v1.method_async.wrap_method( + self.add_context_artifacts_and_executions, + default_timeout=None, + client_info=client_info, + ), + self.add_context_children: gapic_v1.method_async.wrap_method( + self.add_context_children, + default_timeout=None, + client_info=client_info, + ), + self.remove_context_children: gapic_v1.method_async.wrap_method( + self.remove_context_children, + default_timeout=None, + client_info=client_info, + ), + self.query_context_lineage_subgraph: gapic_v1.method_async.wrap_method( + self.query_context_lineage_subgraph, + default_timeout=None, + client_info=client_info, + ), + self.create_execution: gapic_v1.method_async.wrap_method( + self.create_execution, + default_timeout=None, + client_info=client_info, + ), + self.get_execution: gapic_v1.method_async.wrap_method( + self.get_execution, + default_timeout=None, + client_info=client_info, + ), + self.list_executions: gapic_v1.method_async.wrap_method( + self.list_executions, + default_timeout=None, + client_info=client_info, + ), + self.update_execution: 
gapic_v1.method_async.wrap_method( + self.update_execution, + default_timeout=None, + client_info=client_info, + ), + self.delete_execution: gapic_v1.method_async.wrap_method( + self.delete_execution, + default_timeout=None, + client_info=client_info, + ), + self.purge_executions: gapic_v1.method_async.wrap_method( + self.purge_executions, + default_timeout=None, + client_info=client_info, + ), + self.add_execution_events: gapic_v1.method_async.wrap_method( + self.add_execution_events, + default_timeout=None, + client_info=client_info, + ), + self.query_execution_inputs_and_outputs: gapic_v1.method_async.wrap_method( + self.query_execution_inputs_and_outputs, + default_timeout=None, + client_info=client_info, + ), + self.create_metadata_schema: gapic_v1.method_async.wrap_method( + self.create_metadata_schema, + default_timeout=None, + client_info=client_info, + ), + self.get_metadata_schema: gapic_v1.method_async.wrap_method( + self.get_metadata_schema, + default_timeout=None, + client_info=client_info, + ), + self.list_metadata_schemas: gapic_v1.method_async.wrap_method( + self.list_metadata_schemas, + default_timeout=None, + client_info=client_info, + ), + self.query_artifact_lineage_subgraph: gapic_v1.method_async.wrap_method( + self.query_artifact_lineage_subgraph, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/migration_service/async_client.py b/google/cloud/aiplatform_v1/services/migration_service/async_client.py index 877cacfc6c..8b52abe4d5 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/migration_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -216,7 +217,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, 
MigrationServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, MigrationServiceTransport, Callable[..., MigrationServiceTransport] + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -228,9 +233,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.MigrationServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,MigrationServiceTransport,Callable[..., MigrationServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the MigrationServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -350,8 +357,8 @@ async def sample_search_migratable_resources(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -359,7 +366,10 @@ async def sample_search_migratable_resources(): "the individual field arguments should be set." ) - request = migration_service.SearchMigratableResourcesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, migration_service.SearchMigratableResourcesRequest): + request = migration_service.SearchMigratableResourcesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -368,11 +378,9 @@ async def sample_search_migratable_resources(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.search_migratable_resources, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.search_migratable_resources + ] # Certain fields should be provided within the metadata header; # add these here. @@ -494,8 +502,8 @@ async def sample_batch_migrate_resources(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, migrate_resource_requests]) if request is not None and has_flattened_params: raise ValueError( @@ -503,7 +511,10 @@ async def sample_batch_migrate_resources(): "the individual field arguments should be set." ) - request = migration_service.BatchMigrateResourcesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, migration_service.BatchMigrateResourcesRequest): + request = migration_service.BatchMigrateResourcesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -514,11 +525,9 @@ async def sample_batch_migrate_resources(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.batch_migrate_resources, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_migrate_resources + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1/services/migration_service/client.py b/google/cloud/aiplatform_v1/services/migration_service/client.py index 11e08352a5..5e2c922f28 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1/services/migration_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -215,40 +216,40 @@ def parse_annotated_dataset_path(path: str) -> Dict[str, str]: @staticmethod def dataset_path( project: str, + location: str, dataset: str, ) -> str: """Returns a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format( + return "projects/{project}/locations/{location}/datasets/{dataset}".format( project=project, + location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parses a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod def dataset_path( project: str, - location: str, dataset: str, ) -> str: """Returns a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( + return "projects/{project}/datasets/{dataset}".format( project=project, - location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parses a dataset path into its component segments.""" - m = re.match( - 
r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod @@ -664,7 +665,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, MigrationServiceTransport]] = None, + transport: Optional[ + Union[ + str, MigrationServiceTransport, Callable[..., MigrationServiceTransport] + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -676,9 +681,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, MigrationServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,MigrationServiceTransport,Callable[..., MigrationServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the MigrationServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -790,8 +797,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[MigrationServiceTransport], + Callable[..., MigrationServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., MigrationServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -876,8 +891,8 @@ def sample_search_migratable_resources(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -885,10 +900,8 @@ def sample_search_migratable_resources(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a migration_service.SearchMigratableResourcesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, migration_service.SearchMigratableResourcesRequest): request = migration_service.SearchMigratableResourcesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1022,8 +1035,8 @@ def sample_batch_migrate_resources(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, migrate_resource_requests]) if request is not None and has_flattened_params: raise ValueError( @@ -1031,10 +1044,8 @@ def sample_batch_migrate_resources(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a migration_service.BatchMigrateResourcesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, migration_service.BatchMigrateResourcesRequest): request = migration_service.BatchMigrateResourcesRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py index 993c4b8f68..d195551f40 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py @@ -56,7 +56,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -76,14 +76,17 @@ def __init__( credentials identify the application to the service; if none are specified, the 
client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -93,11 +96,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. 
quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -124,7 +127,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. @@ -165,7 +168,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py index 9d520d60a4..9674bb0682 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -71,7 +73,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. 
These are only used when credentials are not specified and are passed to :func:`google.auth.default`. @@ -101,7 +102,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -121,15 +122,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. 
If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -139,11 +143,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -170,7 +174,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -210,7 +214,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -318,6 +324,21 @@ def batch_migrate_resources( ) return self._stubs["batch_migrate_resources"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.search_migratable_resources: gapic_v1.method_async.wrap_method( + self.search_migratable_resources, + default_timeout=None, + client_info=client_info, + ), + self.batch_migrate_resources: gapic_v1.method_async.wrap_method( + self.batch_migrate_resources, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/model_garden_service/async_client.py b/google/cloud/aiplatform_v1/services/model_garden_service/async_client.py index fa3e5c31d1..a1f531e997 100644 --- a/google/cloud/aiplatform_v1/services/model_garden_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/model_garden_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -201,7 +202,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, ModelGardenServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + ModelGardenServiceTransport, + Callable[..., ModelGardenServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -213,9 +220,11 @@ def __init__( credentials identify the application to the service; if none 
are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.ModelGardenServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,ModelGardenServiceTransport,Callable[..., ModelGardenServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ModelGardenServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -324,8 +333,8 @@ async def sample_get_publisher_model(): A Model Garden Publisher Model. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -333,7 +342,10 @@ async def sample_get_publisher_model(): "the individual field arguments should be set." ) - request = model_garden_service.GetPublisherModelRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_garden_service.GetPublisherModelRequest): + request = model_garden_service.GetPublisherModelRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -342,11 +354,9 @@ async def sample_get_publisher_model(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_publisher_model, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_publisher_model + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1/services/model_garden_service/client.py b/google/cloud/aiplatform_v1/services/model_garden_service/client.py index 6b89ee1b17..3892842cfd 100644 --- a/google/cloud/aiplatform_v1/services/model_garden_service/client.py +++ b/google/cloud/aiplatform_v1/services/model_garden_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -528,7 +529,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, ModelGardenServiceTransport]] = None, + transport: Optional[ + Union[ + str, + ModelGardenServiceTransport, + Callable[..., ModelGardenServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -540,9 +547,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ModelGardenServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,ModelGardenServiceTransport,Callable[..., ModelGardenServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ModelGardenServiceTransport constructor. + If set to None, a transport is chosen automatically. 
NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -654,8 +663,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[ModelGardenServiceTransport], + Callable[..., ModelGardenServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., ModelGardenServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -729,8 +746,8 @@ def sample_get_publisher_model(): A Model Garden Publisher Model. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -738,10 +755,8 @@ def sample_get_publisher_model(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_garden_service.GetPublisherModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, model_garden_service.GetPublisherModelRequest): request = model_garden_service.GetPublisherModelRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1/services/model_garden_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/model_garden_service/transports/grpc.py index 3fd5a2c2c3..3c46ac9976 100644 --- a/google/cloud/aiplatform_v1/services/model_garden_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/model_garden_service/transports/grpc.py @@ -55,7 +55,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -75,14 +75,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. 
+ channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -92,11 +95,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -122,7 +125,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -163,7 +166,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/model_garden_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/model_garden_service/transports/grpc_asyncio.py index 7e25a4893e..032d339161 100644 --- a/google/cloud/aiplatform_v1/services/model_garden_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/model_garden_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -70,7 +72,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -100,7 +101,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -120,15 +121,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -138,11 +142,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -168,7 +172,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -208,7 +212,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -266,6 +272,16 @@ def get_publisher_model( ) return self._stubs["get_publisher_model"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.get_publisher_model: gapic_v1.method_async.wrap_method( + self.get_publisher_model, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/model_service/async_client.py b/google/cloud/aiplatform_v1/services/model_service/async_client.py index 7e1543c3d3..666b1fa688 100644 --- a/google/cloud/aiplatform_v1/services/model_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/model_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -227,7 +228,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, ModelServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[str, ModelServiceTransport, Callable[..., ModelServiceTransport]] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -239,9 +242,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.ModelServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. 
+ transport (Optional[Union[str,ModelServiceTransport,Callable[..., ModelServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ModelServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -367,8 +372,8 @@ async def sample_upload_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model]) if request is not None and has_flattened_params: raise ValueError( @@ -376,7 +381,10 @@ async def sample_upload_model(): "the individual field arguments should be set." ) - request = model_service.UploadModelRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.UploadModelRequest): + request = model_service.UploadModelRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -387,11 +395,9 @@ async def sample_upload_model(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.upload_model, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.upload_model + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -491,8 +497,8 @@ async def sample_get_model(): A trained machine learning Model. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -500,7 +506,10 @@ async def sample_get_model(): "the individual field arguments should be set." ) - request = model_service.GetModelRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.GetModelRequest): + request = model_service.GetModelRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -509,11 +518,9 @@ async def sample_get_model(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_model, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_model + ] # Certain fields should be provided within the metadata header; # add these here. @@ -601,8 +608,8 @@ async def sample_list_models(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -610,7 +617,10 @@ async def sample_list_models(): "the individual field arguments should be set." 
) - request = model_service.ListModelsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.ListModelsRequest): + request = model_service.ListModelsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -619,11 +629,9 @@ async def sample_list_models(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_models, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_models + ] # Certain fields should be provided within the metadata header; # add these here. @@ -719,8 +727,8 @@ async def sample_list_model_versions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -728,7 +736,10 @@ async def sample_list_model_versions(): "the individual field arguments should be set." ) - request = model_service.ListModelVersionsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.ListModelVersionsRequest): + request = model_service.ListModelVersionsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -737,11 +748,9 @@ async def sample_list_model_versions(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_model_versions, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_model_versions + ] # Certain fields should be provided within the metadata header; # add these here. @@ -864,8 +873,8 @@ async def sample_update_model(): A trained machine learning Model. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -873,7 +882,10 @@ async def sample_update_model(): "the individual field arguments should be set." ) - request = model_service.UpdateModelRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.UpdateModelRequest): + request = model_service.UpdateModelRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -884,11 +896,9 @@ async def sample_update_model(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_model, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_model + ] # Certain fields should be provided within the metadata header; # add these here. @@ -984,8 +994,8 @@ async def sample_update_explanation_dataset(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model]) if request is not None and has_flattened_params: raise ValueError( @@ -993,7 +1003,10 @@ async def sample_update_explanation_dataset(): "the individual field arguments should be set." ) - request = model_service.UpdateExplanationDatasetRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.UpdateExplanationDatasetRequest): + request = model_service.UpdateExplanationDatasetRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1002,11 +1015,9 @@ async def sample_update_explanation_dataset(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_explanation_dataset, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_explanation_dataset + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1119,8 +1130,8 @@ async def sample_delete_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1128,7 +1139,10 @@ async def sample_delete_model(): "the individual field arguments should be set." 
) - request = model_service.DeleteModelRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.DeleteModelRequest): + request = model_service.DeleteModelRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1137,11 +1151,9 @@ async def sample_delete_model(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_model, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_model + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1256,8 +1268,8 @@ async def sample_delete_model_version(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1265,7 +1277,10 @@ async def sample_delete_model_version(): "the individual field arguments should be set." ) - request = model_service.DeleteModelVersionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.DeleteModelVersionRequest): + request = model_service.DeleteModelVersionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -1274,11 +1289,9 @@ async def sample_delete_model_version(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_model_version, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_model_version + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1392,8 +1405,8 @@ async def sample_merge_version_aliases(): A trained machine learning Model. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, version_aliases]) if request is not None and has_flattened_params: raise ValueError( @@ -1401,7 +1414,10 @@ async def sample_merge_version_aliases(): "the individual field arguments should be set." ) - request = model_service.MergeVersionAliasesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.MergeVersionAliasesRequest): + request = model_service.MergeVersionAliasesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1412,11 +1428,9 @@ async def sample_merge_version_aliases(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.merge_version_aliases, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.merge_version_aliases + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1521,8 +1535,8 @@ async def sample_export_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, output_config]) if request is not None and has_flattened_params: raise ValueError( @@ -1530,7 +1544,10 @@ async def sample_export_model(): "the individual field arguments should be set." ) - request = model_service.ExportModelRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.ExportModelRequest): + request = model_service.ExportModelRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1541,11 +1558,9 @@ async def sample_export_model(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.export_model, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.export_model + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1660,8 +1675,8 @@ async def sample_copy_model(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, source_model]) if request is not None and has_flattened_params: raise ValueError( @@ -1669,7 +1684,10 @@ async def sample_copy_model(): "the individual field arguments should be set." ) - request = model_service.CopyModelRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.CopyModelRequest): + request = model_service.CopyModelRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1680,11 +1698,9 @@ async def sample_copy_model(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.copy_model, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.copy_model + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1787,8 +1803,8 @@ async def sample_import_model_evaluation(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_evaluation]) if request is not None and has_flattened_params: raise ValueError( @@ -1796,7 +1812,10 @@ async def sample_import_model_evaluation(): "the individual field arguments should be set." 
) - request = model_service.ImportModelEvaluationRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.ImportModelEvaluationRequest): + request = model_service.ImportModelEvaluationRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1807,11 +1826,9 @@ async def sample_import_model_evaluation(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.import_model_evaluation, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.import_model_evaluation + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1908,8 +1925,8 @@ async def sample_batch_import_model_evaluation_slices(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_evaluation_slices]) if request is not None and has_flattened_params: raise ValueError( @@ -1917,7 +1934,12 @@ async def sample_batch_import_model_evaluation_slices(): "the individual field arguments should be set." ) - request = model_service.BatchImportModelEvaluationSlicesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance( + request, model_service.BatchImportModelEvaluationSlicesRequest + ): + request = model_service.BatchImportModelEvaluationSlicesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1928,11 +1950,9 @@ async def sample_batch_import_model_evaluation_slices(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.batch_import_model_evaluation_slices, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_import_model_evaluation_slices + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2029,8 +2049,8 @@ async def sample_batch_import_evaluated_annotations(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, evaluated_annotations]) if request is not None and has_flattened_params: raise ValueError( @@ -2038,7 +2058,12 @@ async def sample_batch_import_evaluated_annotations(): "the individual field arguments should be set." ) - request = model_service.BatchImportEvaluatedAnnotationsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, model_service.BatchImportEvaluatedAnnotationsRequest + ): + request = model_service.BatchImportEvaluatedAnnotationsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -2049,11 +2074,9 @@ async def sample_batch_import_evaluated_annotations(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.batch_import_evaluated_annotations, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_import_evaluated_annotations + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2139,8 +2162,8 @@ async def sample_get_model_evaluation(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2148,7 +2171,10 @@ async def sample_get_model_evaluation(): "the individual field arguments should be set." ) - request = model_service.GetModelEvaluationRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.GetModelEvaluationRequest): + request = model_service.GetModelEvaluationRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2157,11 +2183,9 @@ async def sample_get_model_evaluation(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_model_evaluation, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_model_evaluation + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2251,8 +2275,8 @@ async def sample_list_model_evaluations(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2260,7 +2284,10 @@ async def sample_list_model_evaluations(): "the individual field arguments should be set." ) - request = model_service.ListModelEvaluationsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.ListModelEvaluationsRequest): + request = model_service.ListModelEvaluationsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2269,11 +2296,9 @@ async def sample_list_model_evaluations(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_model_evaluations, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_model_evaluations + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2370,8 +2395,8 @@ async def sample_get_model_evaluation_slice(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2379,7 +2404,10 @@ async def sample_get_model_evaluation_slice(): "the individual field arguments should be set." ) - request = model_service.GetModelEvaluationSliceRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.GetModelEvaluationSliceRequest): + request = model_service.GetModelEvaluationSliceRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2388,11 +2416,9 @@ async def sample_get_model_evaluation_slice(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_model_evaluation_slice, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_model_evaluation_slice + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2482,8 +2508,8 @@ async def sample_list_model_evaluation_slices(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2491,7 +2517,10 @@ async def sample_list_model_evaluation_slices(): "the individual field arguments should be set." ) - request = model_service.ListModelEvaluationSlicesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.ListModelEvaluationSlicesRequest): + request = model_service.ListModelEvaluationSlicesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2500,11 +2529,9 @@ async def sample_list_model_evaluation_slices(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_model_evaluation_slices, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_model_evaluation_slices + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1/services/model_service/client.py b/google/cloud/aiplatform_v1/services/model_service/client.py index 1f1e4b4204..ca2b879a8b 100644 --- a/google/cloud/aiplatform_v1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1/services/model_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -662,7 +663,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, ModelServiceTransport]] = None, + transport: Optional[ + Union[str, ModelServiceTransport, Callable[..., ModelServiceTransport]] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -674,9 +677,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ModelServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,ModelServiceTransport,Callable[..., ModelServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ModelServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -785,8 +790,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[ModelServiceTransport], Callable[..., ModelServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., ModelServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -877,8 +889,8 @@ def sample_upload_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model]) if request is not None and has_flattened_params: raise ValueError( @@ -886,10 +898,8 @@ def sample_upload_model(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.UploadModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.UploadModelRequest): request = model_service.UploadModelRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1001,8 +1011,8 @@ def sample_get_model(): A trained machine learning Model. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1010,10 +1020,8 @@ def sample_get_model(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.GetModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.GetModelRequest): request = model_service.GetModelRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1111,8 +1119,8 @@ def sample_list_models(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1120,10 +1128,8 @@ def sample_list_models(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.ListModelsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, model_service.ListModelsRequest): request = model_service.ListModelsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1229,8 +1235,8 @@ def sample_list_model_versions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1238,10 +1244,8 @@ def sample_list_model_versions(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.ListModelVersionsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.ListModelVersionsRequest): request = model_service.ListModelVersionsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1374,8 +1378,8 @@ def sample_update_model(): A trained machine learning Model. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1383,10 +1387,8 @@ def sample_update_model(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.UpdateModelRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.UpdateModelRequest): request = model_service.UpdateModelRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1494,8 +1496,8 @@ def sample_update_explanation_dataset(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model]) if request is not None and has_flattened_params: raise ValueError( @@ -1503,10 +1505,8 @@ def sample_update_explanation_dataset(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.UpdateExplanationDatasetRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.UpdateExplanationDatasetRequest): request = model_service.UpdateExplanationDatasetRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1631,8 +1631,8 @@ def sample_delete_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1640,10 +1640,8 @@ def sample_delete_model(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.DeleteModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.DeleteModelRequest): request = model_service.DeleteModelRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1768,8 +1766,8 @@ def sample_delete_model_version(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1777,10 +1775,8 @@ def sample_delete_model_version(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.DeleteModelVersionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.DeleteModelVersionRequest): request = model_service.DeleteModelVersionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1904,8 +1900,8 @@ def sample_merge_version_aliases(): A trained machine learning Model. """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, version_aliases]) if request is not None and has_flattened_params: raise ValueError( @@ -1913,10 +1909,8 @@ def sample_merge_version_aliases(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.MergeVersionAliasesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.MergeVersionAliasesRequest): request = model_service.MergeVersionAliasesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2033,8 +2027,8 @@ def sample_export_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, output_config]) if request is not None and has_flattened_params: raise ValueError( @@ -2042,10 +2036,8 @@ def sample_export_model(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.ExportModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, model_service.ExportModelRequest): request = model_service.ExportModelRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2172,8 +2164,8 @@ def sample_copy_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, source_model]) if request is not None and has_flattened_params: raise ValueError( @@ -2181,10 +2173,8 @@ def sample_copy_model(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.CopyModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.CopyModelRequest): request = model_service.CopyModelRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2299,8 +2289,8 @@ def sample_import_model_evaluation(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_evaluation]) if request is not None and has_flattened_params: raise ValueError( @@ -2308,10 +2298,8 @@ def sample_import_model_evaluation(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.ImportModelEvaluationRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.ImportModelEvaluationRequest): request = model_service.ImportModelEvaluationRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2420,8 +2408,8 @@ def sample_batch_import_model_evaluation_slices(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_evaluation_slices]) if request is not None and has_flattened_params: raise ValueError( @@ -2429,10 +2417,8 @@ def sample_batch_import_model_evaluation_slices(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.BatchImportModelEvaluationSlicesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, model_service.BatchImportModelEvaluationSlicesRequest ): @@ -2545,8 +2531,8 @@ def sample_batch_import_evaluated_annotations(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, evaluated_annotations]) if request is not None and has_flattened_params: raise ValueError( @@ -2554,10 +2540,8 @@ def sample_batch_import_evaluated_annotations(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.BatchImportEvaluatedAnnotationsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, model_service.BatchImportEvaluatedAnnotationsRequest ): @@ -2659,8 +2643,8 @@ def sample_get_model_evaluation(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2668,10 +2652,8 @@ def sample_get_model_evaluation(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.GetModelEvaluationRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.GetModelEvaluationRequest): request = model_service.GetModelEvaluationRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2771,8 +2753,8 @@ def sample_list_model_evaluations(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2780,10 +2762,8 @@ def sample_list_model_evaluations(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.ListModelEvaluationsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.ListModelEvaluationsRequest): request = model_service.ListModelEvaluationsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2890,8 +2870,8 @@ def sample_get_model_evaluation_slice(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2899,10 +2879,8 @@ def sample_get_model_evaluation_slice(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.GetModelEvaluationSliceRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, model_service.GetModelEvaluationSliceRequest): request = model_service.GetModelEvaluationSliceRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3004,8 +2982,8 @@ def sample_list_model_evaluation_slices(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3013,10 +2991,8 @@ def sample_list_model_evaluation_slices(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.ListModelEvaluationSlicesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, model_service.ListModelEvaluationSlicesRequest): request = model_service.ListModelEvaluationSlicesRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py index e7e20ad52e..a81b4cc9fa 100644 --- a/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py @@ -60,7 +60,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -80,14 +80,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. 
If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -97,11 +100,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -128,7 +131,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -169,7 +172,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py index a97e0ae17a..83db4b2428 100644 --- a/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -75,7 +77,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -105,7 +106,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -125,15 +126,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -143,11 +147,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -174,7 +178,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -214,7 +218,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -797,6 +803,101 @@ def list_model_evaluation_slices( ) return self._stubs["list_model_evaluation_slices"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.upload_model: gapic_v1.method_async.wrap_method( + self.upload_model, + default_timeout=None, + client_info=client_info, + ), + self.get_model: gapic_v1.method_async.wrap_method( + self.get_model, + default_timeout=None, + client_info=client_info, + ), + self.list_models: gapic_v1.method_async.wrap_method( + self.list_models, + default_timeout=None, + client_info=client_info, + ), + self.list_model_versions: gapic_v1.method_async.wrap_method( + self.list_model_versions, + default_timeout=None, + client_info=client_info, + ), + self.update_model: gapic_v1.method_async.wrap_method( + self.update_model, + default_timeout=None, + client_info=client_info, + ), + self.update_explanation_dataset: gapic_v1.method_async.wrap_method( + self.update_explanation_dataset, + default_timeout=None, + client_info=client_info, + ), + self.delete_model: gapic_v1.method_async.wrap_method( + self.delete_model, + default_timeout=None, + client_info=client_info, + ), + self.delete_model_version: gapic_v1.method_async.wrap_method( + self.delete_model_version, + default_timeout=None, + client_info=client_info, + ), + self.merge_version_aliases: gapic_v1.method_async.wrap_method( + self.merge_version_aliases, + default_timeout=None, + client_info=client_info, + ), + self.export_model: gapic_v1.method_async.wrap_method( + self.export_model, + default_timeout=None, + 
client_info=client_info, + ), + self.copy_model: gapic_v1.method_async.wrap_method( + self.copy_model, + default_timeout=None, + client_info=client_info, + ), + self.import_model_evaluation: gapic_v1.method_async.wrap_method( + self.import_model_evaluation, + default_timeout=None, + client_info=client_info, + ), + self.batch_import_model_evaluation_slices: gapic_v1.method_async.wrap_method( + self.batch_import_model_evaluation_slices, + default_timeout=None, + client_info=client_info, + ), + self.batch_import_evaluated_annotations: gapic_v1.method_async.wrap_method( + self.batch_import_evaluated_annotations, + default_timeout=None, + client_info=client_info, + ), + self.get_model_evaluation: gapic_v1.method_async.wrap_method( + self.get_model_evaluation, + default_timeout=None, + client_info=client_info, + ), + self.list_model_evaluations: gapic_v1.method_async.wrap_method( + self.list_model_evaluations, + default_timeout=None, + client_info=client_info, + ), + self.get_model_evaluation_slice: gapic_v1.method_async.wrap_method( + self.get_model_evaluation_slice, + default_timeout=None, + client_info=client_info, + ), + self.list_model_evaluation_slices: gapic_v1.method_async.wrap_method( + self.list_model_evaluation_slices, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/notebook_service/async_client.py b/google/cloud/aiplatform_v1/services/notebook_service/async_client.py index 4193466b6b..935997cddb 100644 --- a/google/cloud/aiplatform_v1/services/notebook_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/notebook_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -223,7 +224,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, NotebookServiceTransport] = "grpc_asyncio", + transport: Optional[ + 
Union[ + str, NotebookServiceTransport, Callable[..., NotebookServiceTransport] + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -235,9 +240,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.NotebookServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,NotebookServiceTransport,Callable[..., NotebookServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the NotebookServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -378,8 +385,8 @@ async def sample_create_notebook_runtime_template(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, notebook_runtime_template, notebook_runtime_template_id] ) @@ -389,7 +396,12 @@ async def sample_create_notebook_runtime_template(): "the individual field arguments should be set." ) - request = notebook_service.CreateNotebookRuntimeTemplateRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance( + request, notebook_service.CreateNotebookRuntimeTemplateRequest + ): + request = notebook_service.CreateNotebookRuntimeTemplateRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -402,11 +414,9 @@ async def sample_create_notebook_runtime_template(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_notebook_runtime_template, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_notebook_runtime_template + ] # Certain fields should be provided within the metadata header; # add these here. @@ -503,8 +513,8 @@ async def sample_get_notebook_runtime_template(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -512,7 +522,10 @@ async def sample_get_notebook_runtime_template(): "the individual field arguments should be set." ) - request = notebook_service.GetNotebookRuntimeTemplateRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, notebook_service.GetNotebookRuntimeTemplateRequest): + request = notebook_service.GetNotebookRuntimeTemplateRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -521,11 +534,9 @@ async def sample_get_notebook_runtime_template(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_notebook_runtime_template, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_notebook_runtime_template + ] # Certain fields should be provided within the metadata header; # add these here. @@ -615,8 +626,8 @@ async def sample_list_notebook_runtime_templates(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -624,7 +635,12 @@ async def sample_list_notebook_runtime_templates(): "the individual field arguments should be set." ) - request = notebook_service.ListNotebookRuntimeTemplatesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, notebook_service.ListNotebookRuntimeTemplatesRequest + ): + request = notebook_service.ListNotebookRuntimeTemplatesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -633,11 +649,9 @@ async def sample_list_notebook_runtime_templates(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_notebook_runtime_templates, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_notebook_runtime_templates + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -746,8 +760,8 @@ async def sample_delete_notebook_runtime_template(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -755,7 +769,12 @@ async def sample_delete_notebook_runtime_template(): "the individual field arguments should be set." ) - request = notebook_service.DeleteNotebookRuntimeTemplateRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, notebook_service.DeleteNotebookRuntimeTemplateRequest + ): + request = notebook_service.DeleteNotebookRuntimeTemplateRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -764,11 +783,9 @@ async def sample_delete_notebook_runtime_template(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_notebook_runtime_template, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_notebook_runtime_template + ] # Certain fields should be provided within the metadata header; # add these here. @@ -905,8 +922,8 @@ async def sample_assign_notebook_runtime(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any( [parent, notebook_runtime_template, notebook_runtime, notebook_runtime_id] ) @@ -916,7 +933,10 @@ async def sample_assign_notebook_runtime(): "the individual field arguments should be set." ) - request = notebook_service.AssignNotebookRuntimeRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, notebook_service.AssignNotebookRuntimeRequest): + request = notebook_service.AssignNotebookRuntimeRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -931,11 +951,9 @@ async def sample_assign_notebook_runtime(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.assign_notebook_runtime, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.assign_notebook_runtime + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1035,8 +1053,8 @@ async def sample_get_notebook_runtime(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1044,7 +1062,10 @@ async def sample_get_notebook_runtime(): "the individual field arguments should be set." ) - request = notebook_service.GetNotebookRuntimeRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, notebook_service.GetNotebookRuntimeRequest): + request = notebook_service.GetNotebookRuntimeRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1053,11 +1074,9 @@ async def sample_get_notebook_runtime(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_notebook_runtime, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_notebook_runtime + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1147,8 +1166,8 @@ async def sample_list_notebook_runtimes(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1156,7 +1175,10 @@ async def sample_list_notebook_runtimes(): "the individual field arguments should be set." ) - request = notebook_service.ListNotebookRuntimesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, notebook_service.ListNotebookRuntimesRequest): + request = notebook_service.ListNotebookRuntimesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1165,11 +1187,9 @@ async def sample_list_notebook_runtimes(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_notebook_runtimes, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_notebook_runtimes + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1282,8 +1302,8 @@ async def sample_delete_notebook_runtime(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1291,7 +1311,10 @@ async def sample_delete_notebook_runtime(): "the individual field arguments should be set." ) - request = notebook_service.DeleteNotebookRuntimeRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, notebook_service.DeleteNotebookRuntimeRequest): + request = notebook_service.DeleteNotebookRuntimeRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1300,11 +1323,9 @@ async def sample_delete_notebook_runtime(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_notebook_runtime, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_notebook_runtime + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1408,8 +1429,8 @@ async def sample_upgrade_notebook_runtime(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1417,7 +1438,10 @@ async def sample_upgrade_notebook_runtime(): "the individual field arguments should be set." ) - request = notebook_service.UpgradeNotebookRuntimeRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, notebook_service.UpgradeNotebookRuntimeRequest): + request = notebook_service.UpgradeNotebookRuntimeRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1426,11 +1450,9 @@ async def sample_upgrade_notebook_runtime(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.upgrade_notebook_runtime, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.upgrade_notebook_runtime + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1534,8 +1556,8 @@ async def sample_start_notebook_runtime(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1543,7 +1565,10 @@ async def sample_start_notebook_runtime(): "the individual field arguments should be set." 
) - request = notebook_service.StartNotebookRuntimeRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, notebook_service.StartNotebookRuntimeRequest): + request = notebook_service.StartNotebookRuntimeRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1552,11 +1577,9 @@ async def sample_start_notebook_runtime(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.start_notebook_runtime, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.start_notebook_runtime + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1/services/notebook_service/client.py b/google/cloud/aiplatform_v1/services/notebook_service/client.py index 2131105155..6f1b8f5bb0 100644 --- a/google/cloud/aiplatform_v1/services/notebook_service/client.py +++ b/google/cloud/aiplatform_v1/services/notebook_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -609,7 +610,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, NotebookServiceTransport]] = None, + transport: Optional[ + Union[ + str, NotebookServiceTransport, Callable[..., NotebookServiceTransport] + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -621,9 +626,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. 
- transport (Union[str, NotebookServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,NotebookServiceTransport,Callable[..., NotebookServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the NotebookServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -735,8 +742,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[NotebookServiceTransport], Callable[..., NotebookServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., NotebookServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -842,8 +856,8 @@ def sample_create_notebook_runtime_template(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, notebook_runtime_template, notebook_runtime_template_id] ) @@ -853,10 +867,8 @@ def sample_create_notebook_runtime_template(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a notebook_service.CreateNotebookRuntimeTemplateRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, notebook_service.CreateNotebookRuntimeTemplateRequest ): @@ -971,8 +983,8 @@ def sample_get_notebook_runtime_template(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -980,10 +992,8 @@ def sample_get_notebook_runtime_template(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a notebook_service.GetNotebookRuntimeTemplateRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, notebook_service.GetNotebookRuntimeTemplateRequest): request = notebook_service.GetNotebookRuntimeTemplateRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1085,8 +1095,8 @@ def sample_list_notebook_runtime_templates(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1094,10 +1104,8 @@ def sample_list_notebook_runtime_templates(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a notebook_service.ListNotebookRuntimeTemplatesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, notebook_service.ListNotebookRuntimeTemplatesRequest ): @@ -1220,8 +1228,8 @@ def sample_delete_notebook_runtime_template(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1229,10 +1237,8 @@ def sample_delete_notebook_runtime_template(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a notebook_service.DeleteNotebookRuntimeTemplateRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, notebook_service.DeleteNotebookRuntimeTemplateRequest ): @@ -1383,8 +1389,8 @@ def sample_assign_notebook_runtime(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, notebook_runtime_template, notebook_runtime, notebook_runtime_id] ) @@ -1394,10 +1400,8 @@ def sample_assign_notebook_runtime(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a notebook_service.AssignNotebookRuntimeRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, notebook_service.AssignNotebookRuntimeRequest): request = notebook_service.AssignNotebookRuntimeRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1513,8 +1517,8 @@ def sample_get_notebook_runtime(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1522,10 +1526,8 @@ def sample_get_notebook_runtime(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a notebook_service.GetNotebookRuntimeRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, notebook_service.GetNotebookRuntimeRequest): request = notebook_service.GetNotebookRuntimeRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1625,8 +1627,8 @@ def sample_list_notebook_runtimes(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1634,10 +1636,8 @@ def sample_list_notebook_runtimes(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a notebook_service.ListNotebookRuntimesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, notebook_service.ListNotebookRuntimesRequest): request = notebook_service.ListNotebookRuntimesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1760,8 +1760,8 @@ def sample_delete_notebook_runtime(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1769,10 +1769,8 @@ def sample_delete_notebook_runtime(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a notebook_service.DeleteNotebookRuntimeRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, notebook_service.DeleteNotebookRuntimeRequest): request = notebook_service.DeleteNotebookRuntimeRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1886,8 +1884,8 @@ def sample_upgrade_notebook_runtime(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1895,10 +1893,8 @@ def sample_upgrade_notebook_runtime(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a notebook_service.UpgradeNotebookRuntimeRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, notebook_service.UpgradeNotebookRuntimeRequest): request = notebook_service.UpgradeNotebookRuntimeRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2012,8 +2008,8 @@ def sample_start_notebook_runtime(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2021,10 +2017,8 @@ def sample_start_notebook_runtime(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a notebook_service.StartNotebookRuntimeRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, notebook_service.StartNotebookRuntimeRequest): request = notebook_service.StartNotebookRuntimeRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1/services/notebook_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/notebook_service/transports/grpc.py index 337a95d69c..a0a3c29990 100644 --- a/google/cloud/aiplatform_v1/services/notebook_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/notebook_service/transports/grpc.py @@ -57,7 +57,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -77,14 +77,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. 
+ This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -94,11 +97,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. 
client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -125,7 +128,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. @@ -166,7 +169,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/notebook_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/notebook_service/transports/grpc_asyncio.py index 02593589cc..14727cdd13 100644 --- a/google/cloud/aiplatform_v1/services/notebook_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/notebook_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -72,7 +74,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -102,7 +103,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -122,15 +123,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -140,11 +144,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -171,7 +175,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -211,7 +215,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -559,6 +565,61 @@ def start_notebook_runtime( ) return self._stubs["start_notebook_runtime"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_notebook_runtime_template: gapic_v1.method_async.wrap_method( + self.create_notebook_runtime_template, + default_timeout=None, + client_info=client_info, + ), + self.get_notebook_runtime_template: gapic_v1.method_async.wrap_method( + self.get_notebook_runtime_template, + default_timeout=None, + client_info=client_info, + ), + self.list_notebook_runtime_templates: gapic_v1.method_async.wrap_method( + self.list_notebook_runtime_templates, + default_timeout=None, + client_info=client_info, + ), + self.delete_notebook_runtime_template: gapic_v1.method_async.wrap_method( + self.delete_notebook_runtime_template, + default_timeout=None, + client_info=client_info, + ), + self.assign_notebook_runtime: gapic_v1.method_async.wrap_method( + self.assign_notebook_runtime, + default_timeout=None, + client_info=client_info, + ), + self.get_notebook_runtime: gapic_v1.method_async.wrap_method( + self.get_notebook_runtime, + default_timeout=None, + client_info=client_info, + ), + self.list_notebook_runtimes: gapic_v1.method_async.wrap_method( + self.list_notebook_runtimes, + default_timeout=None, + client_info=client_info, + ), + self.delete_notebook_runtime: gapic_v1.method_async.wrap_method( + self.delete_notebook_runtime, + default_timeout=None, + client_info=client_info, + ), + self.upgrade_notebook_runtime: gapic_v1.method_async.wrap_method( + 
self.upgrade_notebook_runtime, + default_timeout=None, + client_info=client_info, + ), + self.start_notebook_runtime: gapic_v1.method_async.wrap_method( + self.start_notebook_runtime, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/persistent_resource_service/async_client.py b/google/cloud/aiplatform_v1/services/persistent_resource_service/async_client.py index 9d279e3317..586d39e546 100644 --- a/google/cloud/aiplatform_v1/services/persistent_resource_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/persistent_resource_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -228,7 +229,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, PersistentResourceServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + PersistentResourceServiceTransport, + Callable[..., PersistentResourceServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -240,9 +247,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.PersistentResourceServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,PersistentResourceServiceTransport,Callable[..., PersistentResourceServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the PersistentResourceServiceTransport constructor. 
+ If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -384,8 +393,8 @@ async def sample_create_persistent_resource(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, persistent_resource, persistent_resource_id] ) @@ -395,7 +404,14 @@ async def sample_create_persistent_resource(): "the individual field arguments should be set." ) - request = persistent_resource_service.CreatePersistentResourceRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, persistent_resource_service.CreatePersistentResourceRequest + ): + request = persistent_resource_service.CreatePersistentResourceRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -408,11 +424,9 @@ async def sample_create_persistent_resource(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_persistent_resource, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_persistent_resource + ] # Certain fields should be provided within the metadata header; # add these here. @@ -509,8 +523,8 @@ async def sample_get_persistent_resource(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -518,7 +532,12 @@ async def sample_get_persistent_resource(): "the individual field arguments should be set." ) - request = persistent_resource_service.GetPersistentResourceRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, persistent_resource_service.GetPersistentResourceRequest + ): + request = persistent_resource_service.GetPersistentResourceRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -527,11 +546,9 @@ async def sample_get_persistent_resource(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_persistent_resource, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_persistent_resource + ] # Certain fields should be provided within the metadata header; # add these here. @@ -621,8 +638,8 @@ async def sample_list_persistent_resources(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -630,7 +647,14 @@ async def sample_list_persistent_resources(): "the individual field arguments should be set." ) - request = persistent_resource_service.ListPersistentResourcesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, persistent_resource_service.ListPersistentResourcesRequest + ): + request = persistent_resource_service.ListPersistentResourcesRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -639,11 +663,9 @@ async def sample_list_persistent_resources(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_persistent_resources, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_persistent_resources + ] # Certain fields should be provided within the metadata header; # add these here. @@ -752,8 +774,8 @@ async def sample_delete_persistent_resource(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -761,7 +783,14 @@ async def sample_delete_persistent_resource(): "the individual field arguments should be set." 
) - request = persistent_resource_service.DeletePersistentResourceRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, persistent_resource_service.DeletePersistentResourceRequest + ): + request = persistent_resource_service.DeletePersistentResourceRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -770,11 +799,9 @@ async def sample_delete_persistent_resource(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_persistent_resource, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_persistent_resource + ] # Certain fields should be provided within the metadata header; # add these here. @@ -888,8 +915,8 @@ async def sample_update_persistent_resource(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([persistent_resource, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -897,7 +924,14 @@ async def sample_update_persistent_resource(): "the individual field arguments should be set." ) - request = persistent_resource_service.UpdatePersistentResourceRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance( + request, persistent_resource_service.UpdatePersistentResourceRequest + ): + request = persistent_resource_service.UpdatePersistentResourceRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -908,11 +942,9 @@ async def sample_update_persistent_resource(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_persistent_resource, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_persistent_resource + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1016,8 +1048,8 @@ async def sample_reboot_persistent_resource(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1025,7 +1057,14 @@ async def sample_reboot_persistent_resource(): "the individual field arguments should be set." ) - request = persistent_resource_service.RebootPersistentResourceRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, persistent_resource_service.RebootPersistentResourceRequest + ): + request = persistent_resource_service.RebootPersistentResourceRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -1034,11 +1073,9 @@ async def sample_reboot_persistent_resource(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.reboot_persistent_resource, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.reboot_persistent_resource + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1/services/persistent_resource_service/client.py b/google/cloud/aiplatform_v1/services/persistent_resource_service/client.py index 682cd973f6..29982bf60e 100644 --- a/google/cloud/aiplatform_v1/services/persistent_resource_service/client.py +++ b/google/cloud/aiplatform_v1/services/persistent_resource_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -567,7 +568,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, PersistentResourceServiceTransport]] = None, + transport: Optional[ + Union[ + str, + PersistentResourceServiceTransport, + Callable[..., PersistentResourceServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -579,9 +586,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, PersistentResourceServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,PersistentResourceServiceTransport,Callable[..., PersistentResourceServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. 
+ If a Callable is given, it will be called with the same set of initialization + arguments as used in the PersistentResourceServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -695,8 +704,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[PersistentResourceServiceTransport], + Callable[..., PersistentResourceServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., PersistentResourceServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -803,8 +820,8 @@ def sample_create_persistent_resource(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, persistent_resource, persistent_resource_id] ) @@ -814,10 +831,8 @@ def sample_create_persistent_resource(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a persistent_resource_service.CreatePersistentResourceRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, persistent_resource_service.CreatePersistentResourceRequest ): @@ -934,8 +949,8 @@ def sample_get_persistent_resource(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -943,10 +958,8 @@ def sample_get_persistent_resource(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a persistent_resource_service.GetPersistentResourceRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, persistent_resource_service.GetPersistentResourceRequest ): @@ -1048,8 +1061,8 @@ def sample_list_persistent_resources(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1057,10 +1070,8 @@ def sample_list_persistent_resources(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a persistent_resource_service.ListPersistentResourcesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
+ # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, persistent_resource_service.ListPersistentResourcesRequest ): @@ -1185,8 +1196,8 @@ def sample_delete_persistent_resource(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1194,10 +1205,8 @@ def sample_delete_persistent_resource(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a persistent_resource_service.DeletePersistentResourceRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, persistent_resource_service.DeletePersistentResourceRequest ): @@ -1327,8 +1336,8 @@ def sample_update_persistent_resource(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([persistent_resource, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1336,10 +1345,8 @@ def sample_update_persistent_resource(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a persistent_resource_service.UpdatePersistentResourceRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, persistent_resource_service.UpdatePersistentResourceRequest ): @@ -1461,8 +1468,8 @@ def sample_reboot_persistent_resource(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1470,10 +1477,8 @@ def sample_reboot_persistent_resource(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a persistent_resource_service.RebootPersistentResourceRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, persistent_resource_service.RebootPersistentResourceRequest ): diff --git a/google/cloud/aiplatform_v1/services/persistent_resource_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/persistent_resource_service/transports/grpc.py index ca685bcfba..3b14171233 100644 --- a/google/cloud/aiplatform_v1/services/persistent_resource_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/persistent_resource_service/transports/grpc.py @@ -57,7 +57,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -77,14 +77,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. 
If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -94,11 +97,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -125,7 +128,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -166,7 +169,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/persistent_resource_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/persistent_resource_service/transports/grpc_asyncio.py index bd6b37ef34..d9befa6eac 100644 --- a/google/cloud/aiplatform_v1/services/persistent_resource_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/persistent_resource_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -72,7 +74,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -102,7 +103,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -122,15 +123,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -140,11 +144,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -171,7 +175,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -211,7 +215,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -430,6 +436,41 @@ def reboot_persistent_resource( ) return self._stubs["reboot_persistent_resource"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_persistent_resource: gapic_v1.method_async.wrap_method( + self.create_persistent_resource, + default_timeout=None, + client_info=client_info, + ), + self.get_persistent_resource: gapic_v1.method_async.wrap_method( + self.get_persistent_resource, + default_timeout=None, + client_info=client_info, + ), + self.list_persistent_resources: gapic_v1.method_async.wrap_method( + self.list_persistent_resources, + default_timeout=None, + client_info=client_info, + ), + self.delete_persistent_resource: gapic_v1.method_async.wrap_method( + self.delete_persistent_resource, + default_timeout=None, + client_info=client_info, + ), + self.update_persistent_resource: gapic_v1.method_async.wrap_method( + self.update_persistent_resource, + default_timeout=None, + client_info=client_info, + ), + self.reboot_persistent_resource: gapic_v1.method_async.wrap_method( + self.reboot_persistent_resource, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py index edde11cd61..defade15c3 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py @@ -18,6 +18,7 @@ 
import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -235,7 +236,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, PipelineServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, PipelineServiceTransport, Callable[..., PipelineServiceTransport] + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -247,9 +252,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.PipelineServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,PipelineServiceTransport,Callable[..., PipelineServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the PipelineServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -379,8 +386,8 @@ async def sample_create_training_pipeline(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, training_pipeline]) if request is not None and has_flattened_params: raise ValueError( @@ -388,7 +395,10 @@ async def sample_create_training_pipeline(): "the individual field arguments should be set." 
) - request = pipeline_service.CreateTrainingPipelineRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.CreateTrainingPipelineRequest): + request = pipeline_service.CreateTrainingPipelineRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -399,11 +409,9 @@ async def sample_create_training_pipeline(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_training_pipeline, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_training_pipeline + ] # Certain fields should be provided within the metadata header; # add these here. @@ -493,8 +501,8 @@ async def sample_get_training_pipeline(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -502,7 +510,10 @@ async def sample_get_training_pipeline(): "the individual field arguments should be set." ) - request = pipeline_service.GetTrainingPipelineRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.GetTrainingPipelineRequest): + request = pipeline_service.GetTrainingPipelineRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -511,11 +522,9 @@ async def sample_get_training_pipeline(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_training_pipeline, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_training_pipeline + ] # Certain fields should be provided within the metadata header; # add these here. @@ -605,8 +614,8 @@ async def sample_list_training_pipelines(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -614,7 +623,10 @@ async def sample_list_training_pipelines(): "the individual field arguments should be set." ) - request = pipeline_service.ListTrainingPipelinesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.ListTrainingPipelinesRequest): + request = pipeline_service.ListTrainingPipelinesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -623,11 +635,9 @@ async def sample_list_training_pipelines(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_training_pipelines, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_training_pipelines + ] # Certain fields should be provided within the metadata header; # add these here. @@ -736,8 +746,8 @@ async def sample_delete_training_pipeline(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -745,7 +755,10 @@ async def sample_delete_training_pipeline(): "the individual field arguments should be set." ) - request = pipeline_service.DeleteTrainingPipelineRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.DeleteTrainingPipelineRequest): + request = pipeline_service.DeleteTrainingPipelineRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -754,11 +767,9 @@ async def sample_delete_training_pipeline(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_training_pipeline, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_training_pipeline + ] # Certain fields should be provided within the metadata header; # add these here. @@ -855,8 +866,8 @@ async def sample_cancel_training_pipeline(): sent along with the request as metadata. 
""" # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -864,7 +875,10 @@ async def sample_cancel_training_pipeline(): "the individual field arguments should be set." ) - request = pipeline_service.CancelTrainingPipelineRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.CancelTrainingPipelineRequest): + request = pipeline_service.CancelTrainingPipelineRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -873,11 +887,9 @@ async def sample_cancel_training_pipeline(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_training_pipeline, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.cancel_training_pipeline + ] # Certain fields should be provided within the metadata header; # add these here. @@ -979,8 +991,8 @@ async def sample_create_pipeline_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, pipeline_job, pipeline_job_id]) if request is not None and has_flattened_params: raise ValueError( @@ -988,7 +1000,10 @@ async def sample_create_pipeline_job(): "the individual field arguments should be set." ) - request = pipeline_service.CreatePipelineJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.CreatePipelineJobRequest): + request = pipeline_service.CreatePipelineJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1001,11 +1016,9 @@ async def sample_create_pipeline_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_pipeline_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_pipeline_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1088,8 +1101,8 @@ async def sample_get_pipeline_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1097,7 +1110,10 @@ async def sample_get_pipeline_job(): "the individual field arguments should be set." ) - request = pipeline_service.GetPipelineJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, pipeline_service.GetPipelineJobRequest): + request = pipeline_service.GetPipelineJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1106,11 +1122,9 @@ async def sample_get_pipeline_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_pipeline_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_pipeline_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1198,8 +1212,8 @@ async def sample_list_pipeline_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1207,7 +1221,10 @@ async def sample_list_pipeline_jobs(): "the individual field arguments should be set." ) - request = pipeline_service.ListPipelineJobsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.ListPipelineJobsRequest): + request = pipeline_service.ListPipelineJobsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1216,11 +1233,9 @@ async def sample_list_pipeline_jobs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_pipeline_jobs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_pipeline_jobs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1329,8 +1344,8 @@ async def sample_delete_pipeline_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1338,7 +1353,10 @@ async def sample_delete_pipeline_job(): "the individual field arguments should be set." ) - request = pipeline_service.DeletePipelineJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.DeletePipelineJobRequest): + request = pipeline_service.DeletePipelineJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1347,11 +1365,9 @@ async def sample_delete_pipeline_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_pipeline_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_pipeline_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1464,8 +1480,8 @@ async def sample_batch_delete_pipeline_jobs(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, names]) if request is not None and has_flattened_params: raise ValueError( @@ -1473,7 +1489,10 @@ async def sample_batch_delete_pipeline_jobs(): "the individual field arguments should be set." ) - request = pipeline_service.BatchDeletePipelineJobsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.BatchDeletePipelineJobsRequest): + request = pipeline_service.BatchDeletePipelineJobsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1484,11 +1503,9 @@ async def sample_batch_delete_pipeline_jobs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.batch_delete_pipeline_jobs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_delete_pipeline_jobs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1584,8 +1601,8 @@ async def sample_cancel_pipeline_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1593,7 +1610,10 @@ async def sample_cancel_pipeline_job(): "the individual field arguments should be set." ) - request = pipeline_service.CancelPipelineJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.CancelPipelineJobRequest): + request = pipeline_service.CancelPipelineJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1602,11 +1622,9 @@ async def sample_cancel_pipeline_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_pipeline_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.cancel_pipeline_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1713,8 +1731,8 @@ async def sample_batch_cancel_pipeline_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, names]) if request is not None and has_flattened_params: raise ValueError( @@ -1722,7 +1740,10 @@ async def sample_batch_cancel_pipeline_jobs(): "the individual field arguments should be set." ) - request = pipeline_service.BatchCancelPipelineJobsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, pipeline_service.BatchCancelPipelineJobsRequest): + request = pipeline_service.BatchCancelPipelineJobsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1733,11 +1754,9 @@ async def sample_batch_cancel_pipeline_jobs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.batch_cancel_pipeline_jobs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_cancel_pipeline_jobs + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1/services/pipeline_service/client.py index 57cc556308..a06d5f4e4a 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -729,7 +730,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, PipelineServiceTransport]] = None, + transport: Optional[ + Union[ + str, PipelineServiceTransport, Callable[..., PipelineServiceTransport] + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -741,9 +746,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, PipelineServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. 
+ transport (Optional[Union[str,PipelineServiceTransport,Callable[..., PipelineServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the PipelineServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -855,8 +862,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[PipelineServiceTransport], Callable[..., PipelineServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., PipelineServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -951,8 +965,8 @@ def sample_create_training_pipeline(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, training_pipeline]) if request is not None and has_flattened_params: raise ValueError( @@ -960,10 +974,8 @@ def sample_create_training_pipeline(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.CreateTrainingPipelineRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
+ # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, pipeline_service.CreateTrainingPipelineRequest): request = pipeline_service.CreateTrainingPipelineRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1065,8 +1077,8 @@ def sample_get_training_pipeline(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1074,10 +1086,8 @@ def sample_get_training_pipeline(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.GetTrainingPipelineRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, pipeline_service.GetTrainingPipelineRequest): request = pipeline_service.GetTrainingPipelineRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1177,8 +1187,8 @@ def sample_list_training_pipelines(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1186,10 +1196,8 @@ def sample_list_training_pipelines(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.ListTrainingPipelinesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, pipeline_service.ListTrainingPipelinesRequest): request = pipeline_service.ListTrainingPipelinesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1308,8 +1316,8 @@ def sample_delete_training_pipeline(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1317,10 +1325,8 @@ def sample_delete_training_pipeline(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.DeleteTrainingPipelineRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, pipeline_service.DeleteTrainingPipelineRequest): request = pipeline_service.DeleteTrainingPipelineRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1427,8 +1433,8 @@ def sample_cancel_training_pipeline(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1436,10 +1442,8 @@ def sample_cancel_training_pipeline(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.CancelTrainingPipelineRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, pipeline_service.CancelTrainingPipelineRequest): request = pipeline_service.CancelTrainingPipelineRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1551,8 +1555,8 @@ def sample_create_pipeline_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, pipeline_job, pipeline_job_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1560,10 +1564,8 @@ def sample_create_pipeline_job(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.CreatePipelineJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, pipeline_service.CreatePipelineJobRequest): request = pipeline_service.CreatePipelineJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1660,8 +1662,8 @@ def sample_get_pipeline_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1669,10 +1671,8 @@ def sample_get_pipeline_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.GetPipelineJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, pipeline_service.GetPipelineJobRequest): request = pipeline_service.GetPipelineJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1770,8 +1770,8 @@ def sample_list_pipeline_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1779,10 +1779,8 @@ def sample_list_pipeline_jobs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.ListPipelineJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, pipeline_service.ListPipelineJobsRequest): request = pipeline_service.ListPipelineJobsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1901,8 +1899,8 @@ def sample_delete_pipeline_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1910,10 +1908,8 @@ def sample_delete_pipeline_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.DeletePipelineJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, pipeline_service.DeletePipelineJobRequest): request = pipeline_service.DeletePipelineJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2036,8 +2032,8 @@ def sample_batch_delete_pipeline_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, names]) if request is not None and has_flattened_params: raise ValueError( @@ -2045,10 +2041,8 @@ def sample_batch_delete_pipeline_jobs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.BatchDeletePipelineJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, pipeline_service.BatchDeletePipelineJobsRequest): request = pipeline_service.BatchDeletePipelineJobsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2158,8 +2152,8 @@ def sample_cancel_pipeline_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2167,10 +2161,8 @@ def sample_cancel_pipeline_job(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.CancelPipelineJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, pipeline_service.CancelPipelineJobRequest): request = pipeline_service.CancelPipelineJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2287,8 +2279,8 @@ def sample_batch_cancel_pipeline_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, names]) if request is not None and has_flattened_params: raise ValueError( @@ -2296,10 +2288,8 @@ def sample_batch_cancel_pipeline_jobs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.BatchCancelPipelineJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, pipeline_service.BatchCancelPipelineJobsRequest): request = pipeline_service.BatchCancelPipelineJobsRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py index 45592bd780..f4d15e6032 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py @@ -63,7 +63,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -83,14 +83,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. 
+ channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -100,11 +103,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -131,7 +134,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -172,7 +175,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py index b74ea86642..ace50d52a8 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -78,7 +80,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -108,7 +109,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -128,15 +129,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -146,11 +150,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -177,7 +181,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -217,7 +221,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -643,6 +649,71 @@ def batch_cancel_pipeline_jobs( ) return self._stubs["batch_cancel_pipeline_jobs"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_training_pipeline: gapic_v1.method_async.wrap_method( + self.create_training_pipeline, + default_timeout=None, + client_info=client_info, + ), + self.get_training_pipeline: gapic_v1.method_async.wrap_method( + self.get_training_pipeline, + default_timeout=None, + client_info=client_info, + ), + self.list_training_pipelines: gapic_v1.method_async.wrap_method( + self.list_training_pipelines, + default_timeout=None, + client_info=client_info, + ), + self.delete_training_pipeline: gapic_v1.method_async.wrap_method( + self.delete_training_pipeline, + default_timeout=None, + client_info=client_info, + ), + self.cancel_training_pipeline: gapic_v1.method_async.wrap_method( + self.cancel_training_pipeline, + default_timeout=None, + client_info=client_info, + ), + self.create_pipeline_job: gapic_v1.method_async.wrap_method( + self.create_pipeline_job, + default_timeout=None, + client_info=client_info, + ), + self.get_pipeline_job: gapic_v1.method_async.wrap_method( + self.get_pipeline_job, + default_timeout=None, + client_info=client_info, + ), + self.list_pipeline_jobs: gapic_v1.method_async.wrap_method( + self.list_pipeline_jobs, + default_timeout=None, + client_info=client_info, + ), + self.delete_pipeline_job: gapic_v1.method_async.wrap_method( + self.delete_pipeline_job, + default_timeout=None, + client_info=client_info, + ), + 
self.batch_delete_pipeline_jobs: gapic_v1.method_async.wrap_method( + self.batch_delete_pipeline_jobs, + default_timeout=None, + client_info=client_info, + ), + self.cancel_pipeline_job: gapic_v1.method_async.wrap_method( + self.cancel_pipeline_job, + default_timeout=None, + client_info=client_info, + ), + self.batch_cancel_pipeline_jobs: gapic_v1.method_async.wrap_method( + self.batch_cancel_pipeline_jobs, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1/services/prediction_service/async_client.py index 56bb7cadaa..f94ec84b6b 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -207,7 +208,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, PredictionServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + PredictionServiceTransport, + Callable[..., PredictionServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -219,9 +226,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.PredictionServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,PredictionServiceTransport,Callable[..., PredictionServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. 
+ If a Callable is given, it will be called with the same set of initialization + arguments as used in the PredictionServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -363,8 +372,8 @@ async def sample_predict(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters]) if request is not None and has_flattened_params: raise ValueError( @@ -372,7 +381,10 @@ async def sample_predict(): "the individual field arguments should be set." ) - request = prediction_service.PredictRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.PredictRequest): + request = prediction_service.PredictRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -385,11 +397,7 @@ async def sample_predict(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.predict, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[self._client._transport.predict] # Certain fields should be provided within the metadata header; # add these here. @@ -554,8 +562,8 @@ async def sample_raw_predict(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, http_body]) if request is not None and has_flattened_params: raise ValueError( @@ -563,7 +571,10 @@ async def sample_raw_predict(): "the individual field arguments should be set." ) - request = prediction_service.RawPredictRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.RawPredictRequest): + request = prediction_service.RawPredictRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -574,11 +585,9 @@ async def sample_raw_predict(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.raw_predict, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.raw_predict + ] # Certain fields should be provided within the metadata header; # add these here. @@ -718,8 +727,8 @@ async def sample_stream_raw_predict(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, http_body]) if request is not None and has_flattened_params: raise ValueError( @@ -727,7 +736,10 @@ async def sample_stream_raw_predict(): "the individual field arguments should be set." 
) - request = prediction_service.StreamRawPredictRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.StreamRawPredictRequest): + request = prediction_service.StreamRawPredictRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -738,11 +750,9 @@ async def sample_stream_raw_predict(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.stream_raw_predict, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.stream_raw_predict + ] # Certain fields should be provided within the metadata header; # add these here. @@ -819,15 +829,16 @@ async def sample_direct_predict(): """ # Create or coerce a protobuf request object. - request = prediction_service.DirectPredictRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.DirectPredictRequest): + request = prediction_service.DirectPredictRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.direct_predict, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.direct_predict + ] # Certain fields should be provided within the metadata header; # add these here. @@ -905,15 +916,16 @@ async def sample_direct_raw_predict(): """ # Create or coerce a protobuf request object. 
- request = prediction_service.DirectRawPredictRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.DirectRawPredictRequest): + request = prediction_service.DirectRawPredictRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.direct_raw_predict, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.direct_raw_predict + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1010,11 +1022,9 @@ def request_generator(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.stream_direct_predict, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.stream_direct_predict + ] # Validate the universe domain. self._client._validate_universe_domain() @@ -1110,11 +1120,9 @@ def request_generator(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.stream_direct_raw_predict, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.stream_direct_raw_predict + ] # Validate the universe domain. self._client._validate_universe_domain() @@ -1204,11 +1212,9 @@ def request_generator(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.streaming_predict, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.streaming_predict + ] # Validate the universe domain. self._client._validate_universe_domain() @@ -1286,15 +1292,16 @@ async def sample_server_streaming_predict(): """ # Create or coerce a protobuf request object. - request = prediction_service.StreamingPredictRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.StreamingPredictRequest): + request = prediction_service.StreamingPredictRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.server_streaming_predict, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.server_streaming_predict + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1396,11 +1403,9 @@ def request_generator(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.streaming_raw_predict, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.streaming_raw_predict + ] # Validate the universe domain. self._client._validate_universe_domain() @@ -1530,8 +1535,8 @@ async def sample_explain(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters, deployed_model_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1539,7 +1544,10 @@ async def sample_explain(): "the individual field arguments should be set." ) - request = prediction_service.ExplainRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.ExplainRequest): + request = prediction_service.ExplainRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1554,11 +1562,7 @@ async def sample_explain(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.explain, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[self._client._transport.explain] # Certain fields should be provided within the metadata header; # add these here. @@ -1659,8 +1663,8 @@ async def sample_generate_content(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model, contents]) if request is not None and has_flattened_params: raise ValueError( @@ -1668,7 +1672,10 @@ async def sample_generate_content(): "the individual field arguments should be set." 
) - request = prediction_service.GenerateContentRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.GenerateContentRequest): + request = prediction_service.GenerateContentRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1679,11 +1686,9 @@ async def sample_generate_content(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.generate_content, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.generate_content + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1786,8 +1791,8 @@ async def sample_stream_generate_content(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model, contents]) if request is not None and has_flattened_params: raise ValueError( @@ -1795,7 +1800,10 @@ async def sample_stream_generate_content(): "the individual field arguments should be set." ) - request = prediction_service.GenerateContentRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.GenerateContentRequest): + request = prediction_service.GenerateContentRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -1806,11 +1814,9 @@ async def sample_stream_generate_content(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.stream_generate_content, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.stream_generate_content + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1/services/prediction_service/client.py b/google/cloud/aiplatform_v1/services/prediction_service/client.py index 5b12db202d..4ac3d73427 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/client.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -561,7 +562,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, PredictionServiceTransport]] = None, + transport: Optional[ + Union[ + str, + PredictionServiceTransport, + Callable[..., PredictionServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -573,9 +580,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, PredictionServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,PredictionServiceTransport,Callable[..., PredictionServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. 
+ If a Callable is given, it will be called with the same set of initialization + arguments as used in the PredictionServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -687,8 +696,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[PredictionServiceTransport], + Callable[..., PredictionServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., PredictionServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -795,8 +812,8 @@ def sample_predict(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters]) if request is not None and has_flattened_params: raise ValueError( @@ -804,10 +821,8 @@ def sample_predict(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.PredictRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, prediction_service.PredictRequest): request = prediction_service.PredictRequest(request) # If we have keyword arguments corresponding to fields on the @@ -986,8 +1001,8 @@ def sample_raw_predict(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, http_body]) if request is not None and has_flattened_params: raise ValueError( @@ -995,10 +1010,8 @@ def sample_raw_predict(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.RawPredictRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, prediction_service.RawPredictRequest): request = prediction_service.RawPredictRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1150,8 +1163,8 @@ def sample_stream_raw_predict(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, http_body]) if request is not None and has_flattened_params: raise ValueError( @@ -1159,10 +1172,8 @@ def sample_stream_raw_predict(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.StreamRawPredictRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, prediction_service.StreamRawPredictRequest): request = prediction_service.StreamRawPredictRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1251,10 +1262,8 @@ def sample_direct_predict(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.DirectPredictRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, prediction_service.DirectPredictRequest): request = prediction_service.DirectPredictRequest(request) @@ -1338,10 +1347,8 @@ def sample_direct_raw_predict(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.DirectRawPredictRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, prediction_service.DirectRawPredictRequest): request = prediction_service.DirectRawPredictRequest(request) @@ -1708,10 +1715,8 @@ def sample_server_streaming_predict(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.StreamingPredictRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
+ # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, prediction_service.StreamingPredictRequest): request = prediction_service.StreamingPredictRequest(request) @@ -1949,8 +1954,8 @@ def sample_explain(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters, deployed_model_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1958,10 +1963,8 @@ def sample_explain(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.ExplainRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, prediction_service.ExplainRequest): request = prediction_service.ExplainRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2078,8 +2081,8 @@ def sample_generate_content(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model, contents]) if request is not None and has_flattened_params: raise ValueError( @@ -2087,10 +2090,8 @@ def sample_generate_content(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.GenerateContentRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, prediction_service.GenerateContentRequest): request = prediction_service.GenerateContentRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2205,8 +2206,8 @@ def sample_stream_generate_content(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model, contents]) if request is not None and has_flattened_params: raise ValueError( @@ -2214,10 +2215,8 @@ def sample_stream_generate_content(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.GenerateContentRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, prediction_service.GenerateContentRequest): request = prediction_service.GenerateContentRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py index dd4aed9281..04504a821a 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py @@ -55,7 +55,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -75,14 +75,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. 
+ channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -92,11 +95,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -122,7 +125,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -163,7 +166,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py index 9441109a90..30d3401bcd 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -70,7 +72,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -100,7 +101,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -120,15 +121,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -138,11 +142,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -168,7 +172,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -208,7 +212,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -643,6 +649,76 @@ def stream_generate_content( ) return self._stubs["stream_generate_content"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.predict: gapic_v1.method_async.wrap_method( + self.predict, + default_timeout=None, + client_info=client_info, + ), + self.raw_predict: gapic_v1.method_async.wrap_method( + self.raw_predict, + default_timeout=None, + client_info=client_info, + ), + self.stream_raw_predict: gapic_v1.method_async.wrap_method( + self.stream_raw_predict, + default_timeout=None, + client_info=client_info, + ), + self.direct_predict: gapic_v1.method_async.wrap_method( + self.direct_predict, + default_timeout=None, + client_info=client_info, + ), + self.direct_raw_predict: gapic_v1.method_async.wrap_method( + self.direct_raw_predict, + default_timeout=None, + client_info=client_info, + ), + self.stream_direct_predict: gapic_v1.method_async.wrap_method( + self.stream_direct_predict, + default_timeout=None, + client_info=client_info, + ), + self.stream_direct_raw_predict: gapic_v1.method_async.wrap_method( + self.stream_direct_raw_predict, + default_timeout=None, + client_info=client_info, + ), + self.streaming_predict: gapic_v1.method_async.wrap_method( + self.streaming_predict, + default_timeout=None, + client_info=client_info, + ), + self.server_streaming_predict: gapic_v1.method_async.wrap_method( + self.server_streaming_predict, + default_timeout=None, + client_info=client_info, + ), + self.streaming_raw_predict: gapic_v1.method_async.wrap_method( + 
self.streaming_raw_predict, + default_timeout=None, + client_info=client_info, + ), + self.explain: gapic_v1.method_async.wrap_method( + self.explain, + default_timeout=None, + client_info=client_info, + ), + self.generate_content: gapic_v1.method_async.wrap_method( + self.generate_content, + default_timeout=None, + client_info=client_info, + ), + self.stream_generate_content: gapic_v1.method_async.wrap_method( + self.stream_generate_content, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/schedule_service/async_client.py b/google/cloud/aiplatform_v1/services/schedule_service/async_client.py index 8f7c9cef5d..2e639782e9 100644 --- a/google/cloud/aiplatform_v1/services/schedule_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/schedule_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -223,7 +224,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, ScheduleServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, ScheduleServiceTransport, Callable[..., ScheduleServiceTransport] + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -235,9 +240,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.ScheduleServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,ScheduleServiceTransport,Callable[..., ScheduleServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. 
+ If a Callable is given, it will be called with the same set of initialization + arguments as used in the ScheduleServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -361,8 +368,8 @@ async def sample_create_schedule(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, schedule]) if request is not None and has_flattened_params: raise ValueError( @@ -370,7 +377,10 @@ async def sample_create_schedule(): "the individual field arguments should be set." ) - request = schedule_service.CreateScheduleRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, schedule_service.CreateScheduleRequest): + request = schedule_service.CreateScheduleRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -381,11 +391,9 @@ async def sample_create_schedule(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_schedule, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_schedule + ] # Certain fields should be provided within the metadata header; # add these here. @@ -483,8 +491,8 @@ async def sample_delete_schedule(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -492,7 +500,10 @@ async def sample_delete_schedule(): "the individual field arguments should be set." ) - request = schedule_service.DeleteScheduleRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, schedule_service.DeleteScheduleRequest): + request = schedule_service.DeleteScheduleRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -501,11 +512,9 @@ async def sample_delete_schedule(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_schedule, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_schedule + ] # Certain fields should be provided within the metadata header; # add these here. @@ -598,8 +607,8 @@ async def sample_get_schedule(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -607,7 +616,10 @@ async def sample_get_schedule(): "the individual field arguments should be set." 
) - request = schedule_service.GetScheduleRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, schedule_service.GetScheduleRequest): + request = schedule_service.GetScheduleRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -616,11 +628,9 @@ async def sample_get_schedule(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_schedule, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_schedule + ] # Certain fields should be provided within the metadata header; # add these here. @@ -708,8 +718,8 @@ async def sample_list_schedules(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -717,7 +727,10 @@ async def sample_list_schedules(): "the individual field arguments should be set." ) - request = schedule_service.ListSchedulesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, schedule_service.ListSchedulesRequest): + request = schedule_service.ListSchedulesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -726,11 +739,9 @@ async def sample_list_schedules(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_schedules, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_schedules + ] # Certain fields should be provided within the metadata header; # add these here. @@ -817,8 +828,8 @@ async def sample_pause_schedule(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -826,7 +837,10 @@ async def sample_pause_schedule(): "the individual field arguments should be set." ) - request = schedule_service.PauseScheduleRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, schedule_service.PauseScheduleRequest): + request = schedule_service.PauseScheduleRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -835,11 +849,9 @@ async def sample_pause_schedule(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.pause_schedule, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.pause_schedule + ] # Certain fields should be provided within the metadata header; # add these here. @@ -932,8 +944,8 @@ async def sample_resume_schedule(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, catch_up]) if request is not None and has_flattened_params: raise ValueError( @@ -941,7 +953,10 @@ async def sample_resume_schedule(): "the individual field arguments should be set." ) - request = schedule_service.ResumeScheduleRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, schedule_service.ResumeScheduleRequest): + request = schedule_service.ResumeScheduleRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -952,11 +967,9 @@ async def sample_resume_schedule(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.resume_schedule, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.resume_schedule + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1062,8 +1075,8 @@ async def sample_update_schedule(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([schedule, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1071,7 +1084,10 @@ async def sample_update_schedule(): "the individual field arguments should be set." 
) - request = schedule_service.UpdateScheduleRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, schedule_service.UpdateScheduleRequest): + request = schedule_service.UpdateScheduleRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1082,11 +1098,9 @@ async def sample_update_schedule(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_schedule, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_schedule + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1/services/schedule_service/client.py b/google/cloud/aiplatform_v1/services/schedule_service/client.py index 6446ad1c12..bfcd868901 100644 --- a/google/cloud/aiplatform_v1/services/schedule_service/client.py +++ b/google/cloud/aiplatform_v1/services/schedule_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -679,7 +680,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, ScheduleServiceTransport]] = None, + transport: Optional[ + Union[ + str, ScheduleServiceTransport, Callable[..., ScheduleServiceTransport] + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -691,9 +696,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. 
- transport (Union[str, ScheduleServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,ScheduleServiceTransport,Callable[..., ScheduleServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ScheduleServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -805,8 +812,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[ScheduleServiceTransport], Callable[..., ScheduleServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., ScheduleServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -895,8 +909,8 @@ def sample_create_schedule(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, schedule]) if request is not None and has_flattened_params: raise ValueError( @@ -904,10 +918,8 @@ def sample_create_schedule(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a schedule_service.CreateScheduleRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, schedule_service.CreateScheduleRequest): request = schedule_service.CreateScheduleRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1017,8 +1029,8 @@ def sample_delete_schedule(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1026,10 +1038,8 @@ def sample_delete_schedule(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a schedule_service.DeleteScheduleRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, schedule_service.DeleteScheduleRequest): request = schedule_service.DeleteScheduleRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1132,8 +1142,8 @@ def sample_get_schedule(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1141,10 +1151,8 @@ def sample_get_schedule(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a schedule_service.GetScheduleRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, schedule_service.GetScheduleRequest): request = schedule_service.GetScheduleRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1242,8 +1250,8 @@ def sample_list_schedules(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1251,10 +1259,8 @@ def sample_list_schedules(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a schedule_service.ListSchedulesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, schedule_service.ListSchedulesRequest): request = schedule_service.ListSchedulesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1351,8 +1357,8 @@ def sample_pause_schedule(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1360,10 +1366,8 @@ def sample_pause_schedule(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a schedule_service.PauseScheduleRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, schedule_service.PauseScheduleRequest): request = schedule_service.PauseScheduleRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1466,8 +1470,8 @@ def sample_resume_schedule(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, catch_up]) if request is not None and has_flattened_params: raise ValueError( @@ -1475,10 +1479,8 @@ def sample_resume_schedule(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a schedule_service.ResumeScheduleRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, schedule_service.ResumeScheduleRequest): request = schedule_service.ResumeScheduleRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1596,8 +1598,8 @@ def sample_update_schedule(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([schedule, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1605,10 +1607,8 @@ def sample_update_schedule(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a schedule_service.UpdateScheduleRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, schedule_service.UpdateScheduleRequest): request = schedule_service.UpdateScheduleRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1/services/schedule_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/schedule_service/transports/grpc.py index 1a5252b10e..2d310099e0 100644 --- a/google/cloud/aiplatform_v1/services/schedule_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/schedule_service/transports/grpc.py @@ -60,7 +60,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -80,14 +80,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. 
If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -97,11 +100,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -128,7 +131,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -169,7 +172,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/schedule_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/schedule_service/transports/grpc_asyncio.py index d63b10b892..41e0956b1c 100644 --- a/google/cloud/aiplatform_v1/services/schedule_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/schedule_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -75,7 +77,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -105,7 +106,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -125,15 +126,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -143,11 +147,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -174,7 +178,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -214,7 +218,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -467,6 +473,46 @@ def update_schedule( ) return self._stubs["update_schedule"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_schedule: gapic_v1.method_async.wrap_method( + self.create_schedule, + default_timeout=None, + client_info=client_info, + ), + self.delete_schedule: gapic_v1.method_async.wrap_method( + self.delete_schedule, + default_timeout=None, + client_info=client_info, + ), + self.get_schedule: gapic_v1.method_async.wrap_method( + self.get_schedule, + default_timeout=None, + client_info=client_info, + ), + self.list_schedules: gapic_v1.method_async.wrap_method( + self.list_schedules, + default_timeout=None, + client_info=client_info, + ), + self.pause_schedule: gapic_v1.method_async.wrap_method( + self.pause_schedule, + default_timeout=None, + client_info=client_info, + ), + self.resume_schedule: gapic_v1.method_async.wrap_method( + self.resume_schedule, + default_timeout=None, + client_info=client_info, + ), + self.update_schedule: gapic_v1.method_async.wrap_method( + self.update_schedule, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py index 4503d031dc..809211b5a2 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py @@ -18,6 
+18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -217,7 +218,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, SpecialistPoolServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + SpecialistPoolServiceTransport, + Callable[..., SpecialistPoolServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -229,9 +236,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.SpecialistPoolServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,SpecialistPoolServiceTransport,Callable[..., SpecialistPoolServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the SpecialistPoolServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -367,8 +376,8 @@ async def sample_create_specialist_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, specialist_pool]) if request is not None and has_flattened_params: raise ValueError( @@ -376,7 +385,10 @@ async def sample_create_specialist_pool(): "the individual field arguments should be set." ) - request = specialist_pool_service.CreateSpecialistPoolRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, specialist_pool_service.CreateSpecialistPoolRequest): + request = specialist_pool_service.CreateSpecialistPoolRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -387,11 +399,9 @@ async def sample_create_specialist_pool(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_specialist_pool, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_specialist_pool + ] # Certain fields should be provided within the metadata header; # add these here. @@ -495,8 +505,8 @@ async def sample_get_specialist_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -504,7 +514,10 @@ async def sample_get_specialist_pool(): "the individual field arguments should be set." ) - request = specialist_pool_service.GetSpecialistPoolRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, specialist_pool_service.GetSpecialistPoolRequest): + request = specialist_pool_service.GetSpecialistPoolRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -513,11 +526,9 @@ async def sample_get_specialist_pool(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_specialist_pool, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_specialist_pool + ] # Certain fields should be provided within the metadata header; # add these here. @@ -607,8 +618,8 @@ async def sample_list_specialist_pools(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -616,7 +627,10 @@ async def sample_list_specialist_pools(): "the individual field arguments should be set." ) - request = specialist_pool_service.ListSpecialistPoolsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, specialist_pool_service.ListSpecialistPoolsRequest): + request = specialist_pool_service.ListSpecialistPoolsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -625,11 +639,9 @@ async def sample_list_specialist_pools(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_specialist_pools, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_specialist_pools + ] # Certain fields should be provided within the metadata header; # add these here. @@ -739,8 +751,8 @@ async def sample_delete_specialist_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -748,7 +760,10 @@ async def sample_delete_specialist_pool(): "the individual field arguments should be set." ) - request = specialist_pool_service.DeleteSpecialistPoolRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, specialist_pool_service.DeleteSpecialistPoolRequest): + request = specialist_pool_service.DeleteSpecialistPoolRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -757,11 +772,9 @@ async def sample_delete_specialist_pool(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_specialist_pool, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_specialist_pool + ] # Certain fields should be provided within the metadata header; # add these here. @@ -878,8 +891,8 @@ async def sample_update_specialist_pool(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([specialist_pool, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -887,7 +900,10 @@ async def sample_update_specialist_pool(): "the individual field arguments should be set." ) - request = specialist_pool_service.UpdateSpecialistPoolRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, specialist_pool_service.UpdateSpecialistPoolRequest): + request = specialist_pool_service.UpdateSpecialistPoolRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -898,11 +914,9 @@ async def sample_update_specialist_pool(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_specialist_pool, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_specialist_pool + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py index fd66672433..ffed428ae2 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -547,7 +548,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, SpecialistPoolServiceTransport]] = None, + transport: Optional[ + Union[ + str, + SpecialistPoolServiceTransport, + Callable[..., SpecialistPoolServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -559,9 +566,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, SpecialistPoolServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,SpecialistPoolServiceTransport,Callable[..., SpecialistPoolServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the SpecialistPoolServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -673,8 +682,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[SpecialistPoolServiceTransport], + Callable[..., SpecialistPoolServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., SpecialistPoolServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -775,8 +792,8 @@ def sample_create_specialist_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, specialist_pool]) if request is not None and has_flattened_params: raise ValueError( @@ -784,10 +801,8 @@ def sample_create_specialist_pool(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a specialist_pool_service.CreateSpecialistPoolRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, specialist_pool_service.CreateSpecialistPoolRequest): request = specialist_pool_service.CreateSpecialistPoolRequest(request) # If we have keyword arguments corresponding to fields on the @@ -903,8 +918,8 @@ def sample_get_specialist_pool(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -912,10 +927,8 @@ def sample_get_specialist_pool(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a specialist_pool_service.GetSpecialistPoolRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, specialist_pool_service.GetSpecialistPoolRequest): request = specialist_pool_service.GetSpecialistPoolRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1015,8 +1028,8 @@ def sample_list_specialist_pools(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1024,10 +1037,8 @@ def sample_list_specialist_pools(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a specialist_pool_service.ListSpecialistPoolsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, specialist_pool_service.ListSpecialistPoolsRequest): request = specialist_pool_service.ListSpecialistPoolsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1147,8 +1158,8 @@ def sample_delete_specialist_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1156,10 +1167,8 @@ def sample_delete_specialist_pool(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a specialist_pool_service.DeleteSpecialistPoolRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, specialist_pool_service.DeleteSpecialistPoolRequest): request = specialist_pool_service.DeleteSpecialistPoolRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1286,8 +1295,8 @@ def sample_update_specialist_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([specialist_pool, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1295,10 +1304,8 @@ def sample_update_specialist_pool(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a specialist_pool_service.UpdateSpecialistPoolRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, specialist_pool_service.UpdateSpecialistPoolRequest): request = specialist_pool_service.UpdateSpecialistPoolRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py index e9922bf474..42095120cb 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py @@ -61,7 +61,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -81,14 +81,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. 
scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -98,11 +101,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -129,7 +132,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -170,7 +173,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py index 6024a4f302..10d501611e 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -76,7 +78,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -106,7 +107,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -126,15 +127,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -144,11 +148,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -175,7 +179,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -215,7 +219,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -406,6 +412,36 @@ def update_specialist_pool( ) return self._stubs["update_specialist_pool"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_specialist_pool: gapic_v1.method_async.wrap_method( + self.create_specialist_pool, + default_timeout=None, + client_info=client_info, + ), + self.get_specialist_pool: gapic_v1.method_async.wrap_method( + self.get_specialist_pool, + default_timeout=None, + client_info=client_info, + ), + self.list_specialist_pools: gapic_v1.method_async.wrap_method( + self.list_specialist_pools, + default_timeout=None, + client_info=client_info, + ), + self.delete_specialist_pool: gapic_v1.method_async.wrap_method( + self.delete_specialist_pool, + default_timeout=None, + client_info=client_info, + ), + self.update_specialist_pool: gapic_v1.method_async.wrap_method( + self.update_specialist_pool, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/tensorboard_service/async_client.py b/google/cloud/aiplatform_v1/services/tensorboard_service/async_client.py index 6ec1ab8b51..8647c749ff 100644 --- a/google/cloud/aiplatform_v1/services/tensorboard_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/tensorboard_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -238,7 +239,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - 
transport: Union[str, TensorboardServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + TensorboardServiceTransport, + Callable[..., TensorboardServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -250,9 +257,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.TensorboardServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,TensorboardServiceTransport,Callable[..., TensorboardServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the TensorboardServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -381,8 +390,8 @@ async def sample_create_tensorboard(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard]) if request is not None and has_flattened_params: raise ValueError( @@ -390,7 +399,10 @@ async def sample_create_tensorboard(): "the individual field arguments should be set." ) - request = tensorboard_service.CreateTensorboardRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, tensorboard_service.CreateTensorboardRequest): + request = tensorboard_service.CreateTensorboardRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -401,11 +413,9 @@ async def sample_create_tensorboard(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_tensorboard, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_tensorboard + ] # Certain fields should be provided within the metadata header; # add these here. @@ -502,8 +512,8 @@ async def sample_get_tensorboard(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -511,7 +521,10 @@ async def sample_get_tensorboard(): "the individual field arguments should be set." ) - request = tensorboard_service.GetTensorboardRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.GetTensorboardRequest): + request = tensorboard_service.GetTensorboardRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -520,11 +533,9 @@ async def sample_get_tensorboard(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_tensorboard, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_tensorboard + ] # Certain fields should be provided within the metadata header; # add these here. @@ -634,8 +645,8 @@ async def sample_update_tensorboard(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -643,7 +654,10 @@ async def sample_update_tensorboard(): "the individual field arguments should be set." ) - request = tensorboard_service.UpdateTensorboardRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.UpdateTensorboardRequest): + request = tensorboard_service.UpdateTensorboardRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -654,11 +668,9 @@ async def sample_update_tensorboard(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_tensorboard, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_tensorboard + ] # Certain fields should be provided within the metadata header; # add these here. @@ -758,8 +770,8 @@ async def sample_list_tensorboards(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -767,7 +779,10 @@ async def sample_list_tensorboards(): "the individual field arguments should be set." ) - request = tensorboard_service.ListTensorboardsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.ListTensorboardsRequest): + request = tensorboard_service.ListTensorboardsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -776,11 +791,9 @@ async def sample_list_tensorboards(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_tensorboards, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_tensorboards + ] # Certain fields should be provided within the metadata header; # add these here. @@ -889,8 +902,8 @@ async def sample_delete_tensorboard(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -898,7 +911,10 @@ async def sample_delete_tensorboard(): "the individual field arguments should be set." 
) - request = tensorboard_service.DeleteTensorboardRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.DeleteTensorboardRequest): + request = tensorboard_service.DeleteTensorboardRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -907,11 +923,9 @@ async def sample_delete_tensorboard(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_tensorboard, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_tensorboard + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1005,8 +1019,8 @@ async def sample_read_tensorboard_usage(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard]) if request is not None and has_flattened_params: raise ValueError( @@ -1014,7 +1028,10 @@ async def sample_read_tensorboard_usage(): "the individual field arguments should be set." ) - request = tensorboard_service.ReadTensorboardUsageRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.ReadTensorboardUsageRequest): + request = tensorboard_service.ReadTensorboardUsageRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -1023,11 +1040,9 @@ async def sample_read_tensorboard_usage(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.read_tensorboard_usage, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.read_tensorboard_usage + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1115,8 +1130,8 @@ async def sample_read_tensorboard_size(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard]) if request is not None and has_flattened_params: raise ValueError( @@ -1124,7 +1139,10 @@ async def sample_read_tensorboard_size(): "the individual field arguments should be set." ) - request = tensorboard_service.ReadTensorboardSizeRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.ReadTensorboardSizeRequest): + request = tensorboard_service.ReadTensorboardSizeRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1133,11 +1151,9 @@ async def sample_read_tensorboard_size(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.read_tensorboard_size, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.read_tensorboard_size + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1248,8 +1264,8 @@ async def sample_create_tensorboard_experiment(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, tensorboard_experiment, tensorboard_experiment_id] ) @@ -1259,7 +1275,12 @@ async def sample_create_tensorboard_experiment(): "the individual field arguments should be set." ) - request = tensorboard_service.CreateTensorboardExperimentRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.CreateTensorboardExperimentRequest + ): + request = tensorboard_service.CreateTensorboardExperimentRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1272,11 +1293,9 @@ async def sample_create_tensorboard_experiment(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_tensorboard_experiment, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_tensorboard_experiment + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -1364,8 +1383,8 @@ async def sample_get_tensorboard_experiment(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1373,7 +1392,10 @@ async def sample_get_tensorboard_experiment(): "the individual field arguments should be set." ) - request = tensorboard_service.GetTensorboardExperimentRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.GetTensorboardExperimentRequest): + request = tensorboard_service.GetTensorboardExperimentRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1382,11 +1404,9 @@ async def sample_get_tensorboard_experiment(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_tensorboard_experiment, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_tensorboard_experiment + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1489,8 +1509,8 @@ async def sample_update_tensorboard_experiment(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([tensorboard_experiment, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1498,7 +1518,12 @@ async def sample_update_tensorboard_experiment(): "the individual field arguments should be set." ) - request = tensorboard_service.UpdateTensorboardExperimentRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.UpdateTensorboardExperimentRequest + ): + request = tensorboard_service.UpdateTensorboardExperimentRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1509,11 +1534,9 @@ async def sample_update_tensorboard_experiment(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_tensorboard_experiment, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_tensorboard_experiment + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1605,8 +1628,8 @@ async def sample_list_tensorboard_experiments(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1614,7 +1637,12 @@ async def sample_list_tensorboard_experiments(): "the individual field arguments should be set." 
) - request = tensorboard_service.ListTensorboardExperimentsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.ListTensorboardExperimentsRequest + ): + request = tensorboard_service.ListTensorboardExperimentsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1623,11 +1651,9 @@ async def sample_list_tensorboard_experiments(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_tensorboard_experiments, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_tensorboard_experiments + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1736,8 +1762,8 @@ async def sample_delete_tensorboard_experiment(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1745,7 +1771,12 @@ async def sample_delete_tensorboard_experiment(): "the individual field arguments should be set." ) - request = tensorboard_service.DeleteTensorboardExperimentRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance( + request, tensorboard_service.DeleteTensorboardExperimentRequest + ): + request = tensorboard_service.DeleteTensorboardExperimentRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1754,11 +1785,9 @@ async def sample_delete_tensorboard_experiment(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_tensorboard_experiment, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_tensorboard_experiment + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1879,8 +1908,8 @@ async def sample_create_tensorboard_run(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard_run, tensorboard_run_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1888,7 +1917,10 @@ async def sample_create_tensorboard_run(): "the individual field arguments should be set." ) - request = tensorboard_service.CreateTensorboardRunRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.CreateTensorboardRunRequest): + request = tensorboard_service.CreateTensorboardRunRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1901,11 +1933,9 @@ async def sample_create_tensorboard_run(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_tensorboard_run, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_tensorboard_run + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2012,8 +2042,8 @@ async def sample_batch_create_tensorboard_runs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: raise ValueError( @@ -2021,7 +2051,12 @@ async def sample_batch_create_tensorboard_runs(): "the individual field arguments should be set." ) - request = tensorboard_service.BatchCreateTensorboardRunsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.BatchCreateTensorboardRunsRequest + ): + request = tensorboard_service.BatchCreateTensorboardRunsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2032,11 +2067,9 @@ async def sample_batch_create_tensorboard_runs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.batch_create_tensorboard_runs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_create_tensorboard_runs + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -2124,8 +2157,8 @@ async def sample_get_tensorboard_run(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2133,7 +2166,10 @@ async def sample_get_tensorboard_run(): "the individual field arguments should be set." ) - request = tensorboard_service.GetTensorboardRunRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.GetTensorboardRunRequest): + request = tensorboard_service.GetTensorboardRunRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2142,11 +2178,9 @@ async def sample_get_tensorboard_run(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_tensorboard_run, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_tensorboard_run + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2250,8 +2284,8 @@ async def sample_update_tensorboard_run(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([tensorboard_run, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -2259,7 +2293,10 @@ async def sample_update_tensorboard_run(): "the individual field arguments should be set." ) - request = tensorboard_service.UpdateTensorboardRunRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.UpdateTensorboardRunRequest): + request = tensorboard_service.UpdateTensorboardRunRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2270,11 +2307,9 @@ async def sample_update_tensorboard_run(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_tensorboard_run, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_tensorboard_run + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2366,8 +2401,8 @@ async def sample_list_tensorboard_runs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2375,7 +2410,10 @@ async def sample_list_tensorboard_runs(): "the individual field arguments should be set." ) - request = tensorboard_service.ListTensorboardRunsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, tensorboard_service.ListTensorboardRunsRequest): + request = tensorboard_service.ListTensorboardRunsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2384,11 +2422,9 @@ async def sample_list_tensorboard_runs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_tensorboard_runs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_tensorboard_runs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2497,8 +2533,8 @@ async def sample_delete_tensorboard_run(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2506,7 +2542,10 @@ async def sample_delete_tensorboard_run(): "the individual field arguments should be set." ) - request = tensorboard_service.DeleteTensorboardRunRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.DeleteTensorboardRunRequest): + request = tensorboard_service.DeleteTensorboardRunRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2515,11 +2554,9 @@ async def sample_delete_tensorboard_run(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_tensorboard_run, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_tensorboard_run + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2636,8 +2673,8 @@ async def sample_batch_create_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: raise ValueError( @@ -2645,7 +2682,14 @@ async def sample_batch_create_tensorboard_time_series(): "the individual field arguments should be set." ) - request = tensorboard_service.BatchCreateTensorboardTimeSeriesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.BatchCreateTensorboardTimeSeriesRequest + ): + request = tensorboard_service.BatchCreateTensorboardTimeSeriesRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2656,11 +2700,9 @@ async def sample_batch_create_tensorboard_time_series(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.batch_create_tensorboard_time_series, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_create_tensorboard_time_series + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -2761,8 +2803,8 @@ async def sample_create_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard_time_series]) if request is not None and has_flattened_params: raise ValueError( @@ -2770,7 +2812,12 @@ async def sample_create_tensorboard_time_series(): "the individual field arguments should be set." ) - request = tensorboard_service.CreateTensorboardTimeSeriesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.CreateTensorboardTimeSeriesRequest + ): + request = tensorboard_service.CreateTensorboardTimeSeriesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2781,11 +2828,9 @@ async def sample_create_tensorboard_time_series(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_tensorboard_time_series, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_tensorboard_time_series + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2871,8 +2916,8 @@ async def sample_get_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2880,7 +2925,10 @@ async def sample_get_tensorboard_time_series(): "the individual field arguments should be set." ) - request = tensorboard_service.GetTensorboardTimeSeriesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.GetTensorboardTimeSeriesRequest): + request = tensorboard_service.GetTensorboardTimeSeriesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2889,11 +2937,9 @@ async def sample_get_tensorboard_time_series(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_tensorboard_time_series, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_tensorboard_time_series + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2999,8 +3045,8 @@ async def sample_update_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_time_series, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -3008,7 +3054,12 @@ async def sample_update_tensorboard_time_series(): "the individual field arguments should be set." 
) - request = tensorboard_service.UpdateTensorboardTimeSeriesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.UpdateTensorboardTimeSeriesRequest + ): + request = tensorboard_service.UpdateTensorboardTimeSeriesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3019,11 +3070,9 @@ async def sample_update_tensorboard_time_series(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_tensorboard_time_series, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_tensorboard_time_series + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3120,8 +3169,8 @@ async def sample_list_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3129,7 +3178,12 @@ async def sample_list_tensorboard_time_series(): "the individual field arguments should be set." ) - request = tensorboard_service.ListTensorboardTimeSeriesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance( + request, tensorboard_service.ListTensorboardTimeSeriesRequest + ): + request = tensorboard_service.ListTensorboardTimeSeriesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3138,11 +3192,9 @@ async def sample_list_tensorboard_time_series(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_tensorboard_time_series, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_tensorboard_time_series + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3251,8 +3303,8 @@ async def sample_delete_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3260,7 +3312,12 @@ async def sample_delete_tensorboard_time_series(): "the individual field arguments should be set." ) - request = tensorboard_service.DeleteTensorboardTimeSeriesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.DeleteTensorboardTimeSeriesRequest + ): + request = tensorboard_service.DeleteTensorboardTimeSeriesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -3269,11 +3326,9 @@ async def sample_delete_tensorboard_time_series(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_tensorboard_time_series, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_tensorboard_time_series + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3377,8 +3432,8 @@ async def sample_batch_read_tensorboard_time_series_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard]) if request is not None and has_flattened_params: raise ValueError( @@ -3386,7 +3441,14 @@ async def sample_batch_read_tensorboard_time_series_data(): "the individual field arguments should be set." ) - request = tensorboard_service.BatchReadTensorboardTimeSeriesDataRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.BatchReadTensorboardTimeSeriesDataRequest + ): + request = tensorboard_service.BatchReadTensorboardTimeSeriesDataRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3395,11 +3457,9 @@ async def sample_batch_read_tensorboard_time_series_data(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.batch_read_tensorboard_time_series_data, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_read_tensorboard_time_series_data + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3491,8 +3551,8 @@ async def sample_read_tensorboard_time_series_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_time_series]) if request is not None and has_flattened_params: raise ValueError( @@ -3500,7 +3560,12 @@ async def sample_read_tensorboard_time_series_data(): "the individual field arguments should be set." ) - request = tensorboard_service.ReadTensorboardTimeSeriesDataRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.ReadTensorboardTimeSeriesDataRequest + ): + request = tensorboard_service.ReadTensorboardTimeSeriesDataRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3509,11 +3574,9 @@ async def sample_read_tensorboard_time_series_data(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.read_tensorboard_time_series_data, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.read_tensorboard_time_series_data + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3605,8 +3668,8 @@ async def sample_read_tensorboard_blob_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([time_series]) if request is not None and has_flattened_params: raise ValueError( @@ -3614,7 +3677,10 @@ async def sample_read_tensorboard_blob_data(): "the individual field arguments should be set." ) - request = tensorboard_service.ReadTensorboardBlobDataRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.ReadTensorboardBlobDataRequest): + request = tensorboard_service.ReadTensorboardBlobDataRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3623,11 +3689,9 @@ async def sample_read_tensorboard_blob_data(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.read_tensorboard_blob_data, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.read_tensorboard_blob_data + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -3733,8 +3797,8 @@ async def sample_write_tensorboard_experiment_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_experiment, write_run_data_requests]) if request is not None and has_flattened_params: raise ValueError( @@ -3742,7 +3806,12 @@ async def sample_write_tensorboard_experiment_data(): "the individual field arguments should be set." ) - request = tensorboard_service.WriteTensorboardExperimentDataRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.WriteTensorboardExperimentDataRequest + ): + request = tensorboard_service.WriteTensorboardExperimentDataRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3753,11 +3822,9 @@ async def sample_write_tensorboard_experiment_data(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.write_tensorboard_experiment_data, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.write_tensorboard_experiment_data + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3868,8 +3935,8 @@ async def sample_write_tensorboard_run_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_run, time_series_data]) if request is not None and has_flattened_params: raise ValueError( @@ -3877,7 +3944,10 @@ async def sample_write_tensorboard_run_data(): "the individual field arguments should be set." ) - request = tensorboard_service.WriteTensorboardRunDataRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.WriteTensorboardRunDataRequest): + request = tensorboard_service.WriteTensorboardRunDataRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3888,11 +3958,9 @@ async def sample_write_tensorboard_run_data(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.write_tensorboard_run_data, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.write_tensorboard_run_data + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3985,8 +4053,8 @@ async def sample_export_tensorboard_time_series_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_time_series]) if request is not None and has_flattened_params: raise ValueError( @@ -3994,7 +4062,14 @@ async def sample_export_tensorboard_time_series_data(): "the individual field arguments should be set." 
) - request = tensorboard_service.ExportTensorboardTimeSeriesDataRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.ExportTensorboardTimeSeriesDataRequest + ): + request = tensorboard_service.ExportTensorboardTimeSeriesDataRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -4003,11 +4078,9 @@ async def sample_export_tensorboard_time_series_data(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.export_tensorboard_time_series_data, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.export_tensorboard_time_series_data + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1/services/tensorboard_service/client.py b/google/cloud/aiplatform_v1/services/tensorboard_service/client.py index 85107fb255..062416ecd4 100644 --- a/google/cloud/aiplatform_v1/services/tensorboard_service/client.py +++ b/google/cloud/aiplatform_v1/services/tensorboard_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -633,7 +634,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, TensorboardServiceTransport]] = None, + transport: Optional[ + Union[ + str, + TensorboardServiceTransport, + Callable[..., TensorboardServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -645,9 +652,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, TensorboardServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,TensorboardServiceTransport,Callable[..., TensorboardServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the TensorboardServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -759,8 +768,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[TensorboardServiceTransport], + Callable[..., TensorboardServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., TensorboardServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -854,8 +871,8 @@ def sample_create_tensorboard(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard]) if request is not None and has_flattened_params: raise ValueError( @@ -863,10 +880,8 @@ def sample_create_tensorboard(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.CreateTensorboardRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.CreateTensorboardRequest): request = tensorboard_service.CreateTensorboardRequest(request) # If we have keyword arguments corresponding to fields on the @@ -975,8 +990,8 @@ def sample_get_tensorboard(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -984,10 +999,8 @@ def sample_get_tensorboard(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.GetTensorboardRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.GetTensorboardRequest): request = tensorboard_service.GetTensorboardRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1107,8 +1120,8 @@ def sample_update_tensorboard(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1116,10 +1129,8 @@ def sample_update_tensorboard(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.UpdateTensorboardRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, tensorboard_service.UpdateTensorboardRequest): request = tensorboard_service.UpdateTensorboardRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1231,8 +1242,8 @@ def sample_list_tensorboards(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1240,10 +1251,8 @@ def sample_list_tensorboards(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.ListTensorboardsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.ListTensorboardsRequest): request = tensorboard_service.ListTensorboardsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1362,8 +1371,8 @@ def sample_delete_tensorboard(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1371,10 +1380,8 @@ def sample_delete_tensorboard(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.DeleteTensorboardRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.DeleteTensorboardRequest): request = tensorboard_service.DeleteTensorboardRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1478,8 +1485,8 @@ def sample_read_tensorboard_usage(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard]) if request is not None and has_flattened_params: raise ValueError( @@ -1487,10 +1494,8 @@ def sample_read_tensorboard_usage(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.ReadTensorboardUsageRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.ReadTensorboardUsageRequest): request = tensorboard_service.ReadTensorboardUsageRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1588,8 +1593,8 @@ def sample_read_tensorboard_size(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard]) if request is not None and has_flattened_params: raise ValueError( @@ -1597,10 +1602,8 @@ def sample_read_tensorboard_size(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.ReadTensorboardSizeRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.ReadTensorboardSizeRequest): request = tensorboard_service.ReadTensorboardSizeRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1721,8 +1724,8 @@ def sample_create_tensorboard_experiment(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, tensorboard_experiment, tensorboard_experiment_id] ) @@ -1732,10 +1735,8 @@ def sample_create_tensorboard_experiment(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.CreateTensorboardExperimentRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, tensorboard_service.CreateTensorboardExperimentRequest ): @@ -1841,8 +1842,8 @@ def sample_get_tensorboard_experiment(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1850,10 +1851,8 @@ def sample_get_tensorboard_experiment(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.GetTensorboardExperimentRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.GetTensorboardExperimentRequest): request = tensorboard_service.GetTensorboardExperimentRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1968,8 +1967,8 @@ def sample_update_tensorboard_experiment(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_experiment, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1977,10 +1976,8 @@ def sample_update_tensorboard_experiment(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.UpdateTensorboardExperimentRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.UpdateTensorboardExperimentRequest ): @@ -2088,8 +2085,8 @@ def sample_list_tensorboard_experiments(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2097,10 +2094,8 @@ def sample_list_tensorboard_experiments(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.ListTensorboardExperimentsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.ListTensorboardExperimentsRequest ): @@ -2223,8 +2218,8 @@ def sample_delete_tensorboard_experiment(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2232,10 +2227,8 @@ def sample_delete_tensorboard_experiment(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.DeleteTensorboardExperimentRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.DeleteTensorboardExperimentRequest ): @@ -2370,8 +2363,8 @@ def sample_create_tensorboard_run(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard_run, tensorboard_run_id]) if request is not None and has_flattened_params: raise ValueError( @@ -2379,10 +2372,8 @@ def sample_create_tensorboard_run(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.CreateTensorboardRunRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.CreateTensorboardRunRequest): request = tensorboard_service.CreateTensorboardRunRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2503,8 +2494,8 @@ def sample_batch_create_tensorboard_runs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: raise ValueError( @@ -2512,10 +2503,8 @@ def sample_batch_create_tensorboard_runs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.BatchCreateTensorboardRunsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.BatchCreateTensorboardRunsRequest ): @@ -2619,8 +2608,8 @@ def sample_get_tensorboard_run(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2628,10 +2617,8 @@ def sample_get_tensorboard_run(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.GetTensorboardRunRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, tensorboard_service.GetTensorboardRunRequest): request = tensorboard_service.GetTensorboardRunRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2745,8 +2732,8 @@ def sample_update_tensorboard_run(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_run, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -2754,10 +2741,8 @@ def sample_update_tensorboard_run(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.UpdateTensorboardRunRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.UpdateTensorboardRunRequest): request = tensorboard_service.UpdateTensorboardRunRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2861,8 +2846,8 @@ def sample_list_tensorboard_runs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2870,10 +2855,8 @@ def sample_list_tensorboard_runs(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.ListTensorboardRunsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.ListTensorboardRunsRequest): request = tensorboard_service.ListTensorboardRunsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2992,8 +2975,8 @@ def sample_delete_tensorboard_run(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3001,10 +2984,8 @@ def sample_delete_tensorboard_run(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.DeleteTensorboardRunRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.DeleteTensorboardRunRequest): request = tensorboard_service.DeleteTensorboardRunRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3131,8 +3112,8 @@ def sample_batch_create_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: raise ValueError( @@ -3140,10 +3121,8 @@ def sample_batch_create_tensorboard_time_series(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.BatchCreateTensorboardTimeSeriesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.BatchCreateTensorboardTimeSeriesRequest ): @@ -3262,8 +3241,8 @@ def sample_create_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard_time_series]) if request is not None and has_flattened_params: raise ValueError( @@ -3271,10 +3250,8 @@ def sample_create_tensorboard_time_series(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.CreateTensorboardTimeSeriesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, tensorboard_service.CreateTensorboardTimeSeriesRequest ): @@ -3376,8 +3353,8 @@ def sample_get_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3385,10 +3362,8 @@ def sample_get_tensorboard_time_series(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.GetTensorboardTimeSeriesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.GetTensorboardTimeSeriesRequest): request = tensorboard_service.GetTensorboardTimeSeriesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3506,8 +3481,8 @@ def sample_update_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_time_series, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -3515,10 +3490,8 @@ def sample_update_tensorboard_time_series(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.UpdateTensorboardTimeSeriesRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.UpdateTensorboardTimeSeriesRequest ): @@ -3631,8 +3604,8 @@ def sample_list_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3640,10 +3613,8 @@ def sample_list_tensorboard_time_series(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.ListTensorboardTimeSeriesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.ListTensorboardTimeSeriesRequest ): @@ -3766,8 +3737,8 @@ def sample_delete_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3775,10 +3746,8 @@ def sample_delete_tensorboard_time_series(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.DeleteTensorboardTimeSeriesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.DeleteTensorboardTimeSeriesRequest ): @@ -3896,8 +3865,8 @@ def sample_batch_read_tensorboard_time_series_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard]) if request is not None and has_flattened_params: raise ValueError( @@ -3905,10 +3874,8 @@ def sample_batch_read_tensorboard_time_series_data(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.BatchReadTensorboardTimeSeriesDataRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.BatchReadTensorboardTimeSeriesDataRequest ): @@ -4016,8 +3983,8 @@ def sample_read_tensorboard_time_series_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([tensorboard_time_series]) if request is not None and has_flattened_params: raise ValueError( @@ -4025,10 +3992,8 @@ def sample_read_tensorboard_time_series_data(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.ReadTensorboardTimeSeriesDataRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.ReadTensorboardTimeSeriesDataRequest ): @@ -4134,8 +4099,8 @@ def sample_read_tensorboard_blob_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([time_series]) if request is not None and has_flattened_params: raise ValueError( @@ -4143,10 +4108,8 @@ def sample_read_tensorboard_blob_data(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.ReadTensorboardBlobDataRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.ReadTensorboardBlobDataRequest): request = tensorboard_service.ReadTensorboardBlobDataRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4264,8 +4227,8 @@ def sample_write_tensorboard_experiment_data(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_experiment, write_run_data_requests]) if request is not None and has_flattened_params: raise ValueError( @@ -4273,10 +4236,8 @@ def sample_write_tensorboard_experiment_data(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.WriteTensorboardExperimentDataRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.WriteTensorboardExperimentDataRequest ): @@ -4403,8 +4364,8 @@ def sample_write_tensorboard_run_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_run, time_series_data]) if request is not None and has_flattened_params: raise ValueError( @@ -4412,10 +4373,8 @@ def sample_write_tensorboard_run_data(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.WriteTensorboardRunDataRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, tensorboard_service.WriteTensorboardRunDataRequest): request = tensorboard_service.WriteTensorboardRunDataRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4522,8 +4481,8 @@ def sample_export_tensorboard_time_series_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_time_series]) if request is not None and has_flattened_params: raise ValueError( @@ -4531,10 +4490,8 @@ def sample_export_tensorboard_time_series_data(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.ExportTensorboardTimeSeriesDataRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, tensorboard_service.ExportTensorboardTimeSeriesDataRequest ): diff --git a/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc.py index d79013ef79..51ac956b85 100644 --- a/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc.py @@ -66,7 +66,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -86,14 +86,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. 
If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -103,11 +106,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -134,7 +137,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -175,7 +178,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc_asyncio.py index b6055d4568..f44b14b5fc 100644 --- a/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/tensorboard_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -81,7 +83,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -111,7 +112,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -131,15 +132,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -149,11 +153,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -180,7 +184,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -220,7 +224,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -1183,6 +1189,161 @@ def export_tensorboard_time_series_data( ) return self._stubs["export_tensorboard_time_series_data"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_tensorboard: gapic_v1.method_async.wrap_method( + self.create_tensorboard, + default_timeout=None, + client_info=client_info, + ), + self.get_tensorboard: gapic_v1.method_async.wrap_method( + self.get_tensorboard, + default_timeout=None, + client_info=client_info, + ), + self.update_tensorboard: gapic_v1.method_async.wrap_method( + self.update_tensorboard, + default_timeout=None, + client_info=client_info, + ), + self.list_tensorboards: gapic_v1.method_async.wrap_method( + self.list_tensorboards, + default_timeout=None, + client_info=client_info, + ), + self.delete_tensorboard: gapic_v1.method_async.wrap_method( + self.delete_tensorboard, + default_timeout=None, + client_info=client_info, + ), + self.read_tensorboard_usage: gapic_v1.method_async.wrap_method( + self.read_tensorboard_usage, + default_timeout=None, + client_info=client_info, + ), + self.read_tensorboard_size: gapic_v1.method_async.wrap_method( + self.read_tensorboard_size, + default_timeout=None, + client_info=client_info, + ), + self.create_tensorboard_experiment: gapic_v1.method_async.wrap_method( + self.create_tensorboard_experiment, + default_timeout=None, + client_info=client_info, + ), + self.get_tensorboard_experiment: gapic_v1.method_async.wrap_method( + self.get_tensorboard_experiment, + default_timeout=None, + client_info=client_info, + ), + 
self.update_tensorboard_experiment: gapic_v1.method_async.wrap_method( + self.update_tensorboard_experiment, + default_timeout=None, + client_info=client_info, + ), + self.list_tensorboard_experiments: gapic_v1.method_async.wrap_method( + self.list_tensorboard_experiments, + default_timeout=None, + client_info=client_info, + ), + self.delete_tensorboard_experiment: gapic_v1.method_async.wrap_method( + self.delete_tensorboard_experiment, + default_timeout=None, + client_info=client_info, + ), + self.create_tensorboard_run: gapic_v1.method_async.wrap_method( + self.create_tensorboard_run, + default_timeout=None, + client_info=client_info, + ), + self.batch_create_tensorboard_runs: gapic_v1.method_async.wrap_method( + self.batch_create_tensorboard_runs, + default_timeout=None, + client_info=client_info, + ), + self.get_tensorboard_run: gapic_v1.method_async.wrap_method( + self.get_tensorboard_run, + default_timeout=None, + client_info=client_info, + ), + self.update_tensorboard_run: gapic_v1.method_async.wrap_method( + self.update_tensorboard_run, + default_timeout=None, + client_info=client_info, + ), + self.list_tensorboard_runs: gapic_v1.method_async.wrap_method( + self.list_tensorboard_runs, + default_timeout=None, + client_info=client_info, + ), + self.delete_tensorboard_run: gapic_v1.method_async.wrap_method( + self.delete_tensorboard_run, + default_timeout=None, + client_info=client_info, + ), + self.batch_create_tensorboard_time_series: gapic_v1.method_async.wrap_method( + self.batch_create_tensorboard_time_series, + default_timeout=None, + client_info=client_info, + ), + self.create_tensorboard_time_series: gapic_v1.method_async.wrap_method( + self.create_tensorboard_time_series, + default_timeout=None, + client_info=client_info, + ), + self.get_tensorboard_time_series: gapic_v1.method_async.wrap_method( + self.get_tensorboard_time_series, + default_timeout=None, + client_info=client_info, + ), + self.update_tensorboard_time_series: 
gapic_v1.method_async.wrap_method( + self.update_tensorboard_time_series, + default_timeout=None, + client_info=client_info, + ), + self.list_tensorboard_time_series: gapic_v1.method_async.wrap_method( + self.list_tensorboard_time_series, + default_timeout=None, + client_info=client_info, + ), + self.delete_tensorboard_time_series: gapic_v1.method_async.wrap_method( + self.delete_tensorboard_time_series, + default_timeout=None, + client_info=client_info, + ), + self.batch_read_tensorboard_time_series_data: gapic_v1.method_async.wrap_method( + self.batch_read_tensorboard_time_series_data, + default_timeout=None, + client_info=client_info, + ), + self.read_tensorboard_time_series_data: gapic_v1.method_async.wrap_method( + self.read_tensorboard_time_series_data, + default_timeout=None, + client_info=client_info, + ), + self.read_tensorboard_blob_data: gapic_v1.method_async.wrap_method( + self.read_tensorboard_blob_data, + default_timeout=None, + client_info=client_info, + ), + self.write_tensorboard_experiment_data: gapic_v1.method_async.wrap_method( + self.write_tensorboard_experiment_data, + default_timeout=None, + client_info=client_info, + ), + self.write_tensorboard_run_data: gapic_v1.method_async.wrap_method( + self.write_tensorboard_run_data, + default_timeout=None, + client_info=client_info, + ), + self.export_tensorboard_time_series_data: gapic_v1.method_async.wrap_method( + self.export_tensorboard_time_series_data, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/services/vizier_service/async_client.py b/google/cloud/aiplatform_v1/services/vizier_service/async_client.py index 2a7937226d..38219516a0 100644 --- a/google/cloud/aiplatform_v1/services/vizier_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/vizier_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, 
@@ -211,7 +212,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, VizierServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[str, VizierServiceTransport, Callable[..., VizierServiceTransport]] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -223,9 +226,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.VizierServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,VizierServiceTransport,Callable[..., VizierServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the VizierServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -350,8 +355,8 @@ async def sample_create_study(): A message representing a Study. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, study]) if request is not None and has_flattened_params: raise ValueError( @@ -359,7 +364,10 @@ async def sample_create_study(): "the individual field arguments should be set." 
) - request = vizier_service.CreateStudyRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.CreateStudyRequest): + request = vizier_service.CreateStudyRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -370,11 +378,9 @@ async def sample_create_study(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_study, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_study + ] # Certain fields should be provided within the metadata header; # add these here. @@ -455,8 +461,8 @@ async def sample_get_study(): A message representing a Study. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -464,7 +470,10 @@ async def sample_get_study(): "the individual field arguments should be set." ) - request = vizier_service.GetStudyRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.GetStudyRequest): + request = vizier_service.GetStudyRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -473,11 +482,9 @@ async def sample_get_study(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_study, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_study + ] # Certain fields should be provided within the metadata header; # add these here. @@ -566,8 +573,8 @@ async def sample_list_studies(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -575,7 +582,10 @@ async def sample_list_studies(): "the individual field arguments should be set." ) - request = vizier_service.ListStudiesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.ListStudiesRequest): + request = vizier_service.ListStudiesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -584,11 +594,9 @@ async def sample_list_studies(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_studies, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_studies + ] # Certain fields should be provided within the metadata header; # add these here. @@ -672,8 +680,8 @@ async def sample_delete_study(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -681,7 +689,10 @@ async def sample_delete_study(): "the individual field arguments should be set." ) - request = vizier_service.DeleteStudyRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.DeleteStudyRequest): + request = vizier_service.DeleteStudyRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -690,11 +701,9 @@ async def sample_delete_study(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_study, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_study + ] # Certain fields should be provided within the metadata header; # add these here. @@ -775,8 +784,8 @@ async def sample_lookup_study(): A message representing a Study. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -784,7 +793,10 @@ async def sample_lookup_study(): "the individual field arguments should be set." 
) - request = vizier_service.LookupStudyRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.LookupStudyRequest): + request = vizier_service.LookupStudyRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -793,11 +805,9 @@ async def sample_lookup_study(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.lookup_study, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.lookup_study + ] # Certain fields should be provided within the metadata header; # add these here. @@ -884,15 +894,16 @@ async def sample_suggest_trials(): """ # Create or coerce a protobuf request object. - request = vizier_service.SuggestTrialsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.SuggestTrialsRequest): + request = vizier_service.SuggestTrialsRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.suggest_trials, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.suggest_trials + ] # Certain fields should be provided within the metadata header; # add these here. @@ -993,8 +1004,8 @@ async def sample_create_trial(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, trial]) if request is not None and has_flattened_params: raise ValueError( @@ -1002,7 +1013,10 @@ async def sample_create_trial(): "the individual field arguments should be set." ) - request = vizier_service.CreateTrialRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.CreateTrialRequest): + request = vizier_service.CreateTrialRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1013,11 +1027,9 @@ async def sample_create_trial(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_trial, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_trial + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1103,8 +1115,8 @@ async def sample_get_trial(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1112,7 +1124,10 @@ async def sample_get_trial(): "the individual field arguments should be set." ) - request = vizier_service.GetTrialRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, vizier_service.GetTrialRequest): + request = vizier_service.GetTrialRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1121,11 +1136,9 @@ async def sample_get_trial(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_trial, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_trial + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1213,8 +1226,8 @@ async def sample_list_trials(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1222,7 +1235,10 @@ async def sample_list_trials(): "the individual field arguments should be set." ) - request = vizier_service.ListTrialsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.ListTrialsRequest): + request = vizier_service.ListTrialsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1231,11 +1247,9 @@ async def sample_list_trials(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_trials, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_trials + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1326,15 +1340,16 @@ async def sample_add_trial_measurement(): """ # Create or coerce a protobuf request object. - request = vizier_service.AddTrialMeasurementRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.AddTrialMeasurementRequest): + request = vizier_service.AddTrialMeasurementRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.add_trial_measurement, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.add_trial_measurement + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1414,15 +1429,16 @@ async def sample_complete_trial(): """ # Create or coerce a protobuf request object. - request = vizier_service.CompleteTrialRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.CompleteTrialRequest): + request = vizier_service.CompleteTrialRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.complete_trial, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.complete_trial + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1496,8 +1512,8 @@ async def sample_delete_trial(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1505,7 +1521,10 @@ async def sample_delete_trial(): "the individual field arguments should be set." ) - request = vizier_service.DeleteTrialRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.DeleteTrialRequest): + request = vizier_service.DeleteTrialRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1514,11 +1533,9 @@ async def sample_delete_trial(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_trial, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_trial + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1601,15 +1618,16 @@ async def sample_check_trial_early_stopping_state(): """ # Create or coerce a protobuf request object. 
- request = vizier_service.CheckTrialEarlyStoppingStateRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.CheckTrialEarlyStoppingStateRequest): + request = vizier_service.CheckTrialEarlyStoppingStateRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.check_trial_early_stopping_state, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.check_trial_early_stopping_state + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1697,15 +1715,16 @@ async def sample_stop_trial(): """ # Create or coerce a protobuf request object. - request = vizier_service.StopTrialRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.StopTrialRequest): + request = vizier_service.StopTrialRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.stop_trial, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.stop_trial + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1791,8 +1810,8 @@ async def sample_list_optimal_trials(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1800,7 +1819,10 @@ async def sample_list_optimal_trials(): "the individual field arguments should be set." ) - request = vizier_service.ListOptimalTrialsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.ListOptimalTrialsRequest): + request = vizier_service.ListOptimalTrialsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1809,11 +1831,9 @@ async def sample_list_optimal_trials(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_optimal_trials, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_optimal_trials + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1/services/vizier_service/client.py b/google/cloud/aiplatform_v1/services/vizier_service/client.py index 29a5613b92..dc15e77f9a 100644 --- a/google/cloud/aiplatform_v1/services/vizier_service/client.py +++ b/google/cloud/aiplatform_v1/services/vizier_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -586,7 +587,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, VizierServiceTransport]] = None, + transport: Optional[ + Union[str, VizierServiceTransport, Callable[..., VizierServiceTransport]] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -598,9 +601,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, VizierServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,VizierServiceTransport,Callable[..., VizierServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the VizierServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -712,8 +717,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[VizierServiceTransport], Callable[..., VizierServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., VizierServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -803,8 +815,8 @@ def sample_create_study(): A message representing a Study. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, study]) if request is not None and has_flattened_params: raise ValueError( @@ -812,10 +824,8 @@ def sample_create_study(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.CreateStudyRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.CreateStudyRequest): request = vizier_service.CreateStudyRequest(request) # If we have keyword arguments corresponding to fields on the @@ -908,8 +918,8 @@ def sample_get_study(): A message representing a Study. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -917,10 +927,8 @@ def sample_get_study(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.GetStudyRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.GetStudyRequest): request = vizier_service.GetStudyRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1019,8 +1027,8 @@ def sample_list_studies(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1028,10 +1036,8 @@ def sample_list_studies(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.ListStudiesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, vizier_service.ListStudiesRequest): request = vizier_service.ListStudiesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1125,8 +1131,8 @@ def sample_delete_study(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1134,10 +1140,8 @@ def sample_delete_study(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.DeleteStudyRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.DeleteStudyRequest): request = vizier_service.DeleteStudyRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1228,8 +1232,8 @@ def sample_lookup_study(): A message representing a Study. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1237,10 +1241,8 @@ def sample_lookup_study(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.LookupStudyRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.LookupStudyRequest): request = vizier_service.LookupStudyRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1337,10 +1339,8 @@ def sample_suggest_trials(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.SuggestTrialsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.SuggestTrialsRequest): request = vizier_service.SuggestTrialsRequest(request) @@ -1447,8 +1447,8 @@ def sample_create_trial(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, trial]) if request is not None and has_flattened_params: raise ValueError( @@ -1456,10 +1456,8 @@ def sample_create_trial(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.CreateTrialRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, vizier_service.CreateTrialRequest): request = vizier_service.CreateTrialRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1557,8 +1555,8 @@ def sample_get_trial(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1566,10 +1564,8 @@ def sample_get_trial(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.GetTrialRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.GetTrialRequest): request = vizier_service.GetTrialRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1667,8 +1663,8 @@ def sample_list_trials(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1676,10 +1672,8 @@ def sample_list_trials(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.ListTrialsRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.ListTrialsRequest): request = vizier_service.ListTrialsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1780,10 +1774,8 @@ def sample_add_trial_measurement(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.AddTrialMeasurementRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.AddTrialMeasurementRequest): request = vizier_service.AddTrialMeasurementRequest(request) @@ -1869,10 +1861,8 @@ def sample_complete_trial(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.CompleteTrialRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.CompleteTrialRequest): request = vizier_service.CompleteTrialRequest(request) @@ -1952,8 +1942,8 @@ def sample_delete_trial(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1961,10 +1951,8 @@ def sample_delete_trial(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.DeleteTrialRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.DeleteTrialRequest): request = vizier_service.DeleteTrialRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2057,10 +2045,8 @@ def sample_check_trial_early_stopping_state(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.CheckTrialEarlyStoppingStateRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.CheckTrialEarlyStoppingStateRequest): request = vizier_service.CheckTrialEarlyStoppingStateRequest(request) @@ -2156,10 +2142,8 @@ def sample_stop_trial(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.StopTrialRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, vizier_service.StopTrialRequest): request = vizier_service.StopTrialRequest(request) @@ -2251,8 +2235,8 @@ def sample_list_optimal_trials(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2260,10 +2244,8 @@ def sample_list_optimal_trials(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.ListOptimalTrialsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, vizier_service.ListOptimalTrialsRequest): request = vizier_service.ListOptimalTrialsRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1/services/vizier_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/vizier_service/transports/grpc.py index 7d721f412d..6d5ac597f9 100644 --- a/google/cloud/aiplatform_v1/services/vizier_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/vizier_service/transports/grpc.py @@ -62,7 +62,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -82,14 +82,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. 
If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -99,11 +102,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -130,7 +133,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -171,7 +174,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1/services/vizier_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/vizier_service/transports/grpc_asyncio.py index d6403249b4..3679685b7c 100644 --- a/google/cloud/aiplatform_v1/services/vizier_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/vizier_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -77,7 +79,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -107,7 +108,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -127,15 +128,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -145,11 +149,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -176,7 +180,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -216,7 +220,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -682,6 +688,86 @@ def list_optimal_trials( ) return self._stubs["list_optimal_trials"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_study: gapic_v1.method_async.wrap_method( + self.create_study, + default_timeout=None, + client_info=client_info, + ), + self.get_study: gapic_v1.method_async.wrap_method( + self.get_study, + default_timeout=None, + client_info=client_info, + ), + self.list_studies: gapic_v1.method_async.wrap_method( + self.list_studies, + default_timeout=None, + client_info=client_info, + ), + self.delete_study: gapic_v1.method_async.wrap_method( + self.delete_study, + default_timeout=None, + client_info=client_info, + ), + self.lookup_study: gapic_v1.method_async.wrap_method( + self.lookup_study, + default_timeout=None, + client_info=client_info, + ), + self.suggest_trials: gapic_v1.method_async.wrap_method( + self.suggest_trials, + default_timeout=None, + client_info=client_info, + ), + self.create_trial: gapic_v1.method_async.wrap_method( + self.create_trial, + default_timeout=None, + client_info=client_info, + ), + self.get_trial: gapic_v1.method_async.wrap_method( + self.get_trial, + default_timeout=None, + client_info=client_info, + ), + self.list_trials: gapic_v1.method_async.wrap_method( + self.list_trials, + default_timeout=None, + client_info=client_info, + ), + self.add_trial_measurement: gapic_v1.method_async.wrap_method( + self.add_trial_measurement, + default_timeout=None, + client_info=client_info, + ), + self.complete_trial: 
gapic_v1.method_async.wrap_method( + self.complete_trial, + default_timeout=None, + client_info=client_info, + ), + self.delete_trial: gapic_v1.method_async.wrap_method( + self.delete_trial, + default_timeout=None, + client_info=client_info, + ), + self.check_trial_early_stopping_state: gapic_v1.method_async.wrap_method( + self.check_trial_early_stopping_state, + default_timeout=None, + client_info=client_info, + ), + self.stop_trial: gapic_v1.method_async.wrap_method( + self.stop_trial, + default_timeout=None, + client_info=client_info, + ), + self.list_optimal_trials: gapic_v1.method_async.wrap_method( + self.list_optimal_trials, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1/types/__init__.py b/google/cloud/aiplatform_v1/types/__init__.py index 179dddee00..268cf2a02a 100644 --- a/google/cloud/aiplatform_v1/types/__init__.py +++ b/google/cloud/aiplatform_v1/types/__init__.py @@ -44,6 +44,7 @@ Part, SafetyRating, SafetySetting, + SearchEntryPoint, Segment, VideoMetadata, HarmCategory, @@ -934,6 +935,7 @@ "Part", "SafetyRating", "SafetySetting", + "SearchEntryPoint", "Segment", "VideoMetadata", "HarmCategory", diff --git a/google/cloud/aiplatform_v1/types/accelerator_type.py b/google/cloud/aiplatform_v1/types/accelerator_type.py index 8ea4d56aba..041207152f 100644 --- a/google/cloud/aiplatform_v1/types/accelerator_type.py +++ b/google/cloud/aiplatform_v1/types/accelerator_type.py @@ -59,6 +59,8 @@ class AcceleratorType(proto.Enum): TPU v3. TPU_V4_POD (10): TPU v4. + TPU_V5_LITEPOD (12): + TPU v5. 
""" ACCELERATOR_TYPE_UNSPECIFIED = 0 NVIDIA_TESLA_K80 = 1 @@ -73,6 +75,7 @@ class AcceleratorType(proto.Enum): TPU_V2 = 6 TPU_V3 = 7 TPU_V4_POD = 10 + TPU_V5_LITEPOD = 12 __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/content.py b/google/cloud/aiplatform_v1/types/content.py index 0c4b6b408e..ce8431e7d4 100644 --- a/google/cloud/aiplatform_v1/types/content.py +++ b/google/cloud/aiplatform_v1/types/content.py @@ -42,6 +42,7 @@ "Segment", "GroundingAttribution", "GroundingMetadata", + "SearchEntryPoint", }, ) @@ -788,12 +789,19 @@ class Web(proto.Message): class GroundingMetadata(proto.Message): r"""Metadata returned to client when grounding is enabled. + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + Attributes: web_search_queries (MutableSequence[str]): Optional. Web search queries for the following-up web search. grounding_attributions (MutableSequence[google.cloud.aiplatform_v1.types.GroundingAttribution]): Optional. List of grounding attributions. + search_entry_point (google.cloud.aiplatform_v1.types.SearchEntryPoint): + Optional. Google search entry for the + following-up web searches. + + This field is a member of `oneof`_ ``_search_entry_point``. """ web_search_queries: MutableSequence[str] = proto.RepeatedField( @@ -807,6 +815,34 @@ class GroundingMetadata(proto.Message): number=2, message="GroundingAttribution", ) + search_entry_point: "SearchEntryPoint" = proto.Field( + proto.MESSAGE, + number=4, + optional=True, + message="SearchEntryPoint", + ) + + +class SearchEntryPoint(proto.Message): + r"""Google search entry point. + + Attributes: + rendered_content (str): + Optional. Web content snippet that can be + embedded in a web page or an app webview. + sdk_blob (bytes): + Optional. Base64 encoded JSON representing + array of tuple. 
+ """ + + rendered_content: str = proto.Field( + proto.STRING, + number=1, + ) + sdk_blob: bytes = proto.Field( + proto.BYTES, + number=2, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/endpoint.py b/google/cloud/aiplatform_v1/types/endpoint.py index 781b6da199..b62b1237fa 100644 --- a/google/cloud/aiplatform_v1/types/endpoint.py +++ b/google/cloud/aiplatform_v1/types/endpoint.py @@ -23,6 +23,7 @@ from google.cloud.aiplatform_v1.types import explanation from google.cloud.aiplatform_v1.types import io from google.cloud.aiplatform_v1.types import machine_resources +from google.cloud.aiplatform_v1.types import service_networking from google.protobuf import timestamp_pb2 # type: ignore @@ -121,6 +122,12 @@ class Endpoint(proto.Message): [network][google.cloud.aiplatform.v1.Endpoint.network] or [enable_private_service_connect][google.cloud.aiplatform.v1.Endpoint.enable_private_service_connect], can be set. + private_service_connect_config (google.cloud.aiplatform_v1.types.PrivateServiceConnectConfig): + Optional. Configuration for private service connect. + + [network][google.cloud.aiplatform.v1.Endpoint.network] and + [private_service_connect_config][google.cloud.aiplatform.v1.Endpoint.private_service_connect_config] + are mutually exclusive. model_deployment_monitoring_job (str): Output only. 
Resource name of the Model Monitoring job associated with this Endpoint if monitoring is enabled by @@ -186,6 +193,13 @@ class Endpoint(proto.Message): proto.BOOL, number=17, ) + private_service_connect_config: service_networking.PrivateServiceConnectConfig = ( + proto.Field( + proto.MESSAGE, + number=21, + message=service_networking.PrivateServiceConnectConfig, + ) + ) model_deployment_monitoring_job: str = proto.Field( proto.STRING, number=14, diff --git a/google/cloud/aiplatform_v1/types/feature_registry_service.py b/google/cloud/aiplatform_v1/types/feature_registry_service.py index 0ff6e182a4..fcdd9ebd33 100644 --- a/google/cloud/aiplatform_v1/types/feature_registry_service.py +++ b/google/cloud/aiplatform_v1/types/feature_registry_service.py @@ -49,7 +49,7 @@ class CreateFeatureGroupRequest(proto.Message): parent (str): Required. The resource name of the Location to create FeatureGroups. Format: - ``projects/{project}/locations/{location}'`` + ``projects/{project}/locations/{location}`` feature_group (google.cloud.aiplatform_v1.types.FeatureGroup): Required. The FeatureGroup to create. feature_group_id (str): diff --git a/google/cloud/aiplatform_v1/types/index_service.py b/google/cloud/aiplatform_v1/types/index_service.py index 1440f1b0f5..1af6c823d4 100644 --- a/google/cloud/aiplatform_v1/types/index_service.py +++ b/google/cloud/aiplatform_v1/types/index_service.py @@ -401,6 +401,8 @@ class RecordErrorType(proto.Enum): specified. INVALID_ENCODING (13): File is not in UTF_8 format. + INVALID_TOKEN_VALUE (15): + Token restrict value is invalid. 
""" ERROR_TYPE_UNSPECIFIED = 0 EMPTY_LINE = 1 @@ -416,6 +418,7 @@ class RecordErrorType(proto.Enum): MULTIPLE_VALUES = 11 INVALID_NUMERIC_VALUE = 12 INVALID_ENCODING = 13 + INVALID_TOKEN_VALUE = 15 error_type: "NearestNeighborSearchOperationMetadata.RecordError.RecordErrorType" = proto.Field( proto.ENUM, diff --git a/google/cloud/aiplatform_v1/types/notebook_runtime.py b/google/cloud/aiplatform_v1/types/notebook_runtime.py index ecc1b088bb..6f0c4023d5 100644 --- a/google/cloud/aiplatform_v1/types/notebook_runtime.py +++ b/google/cloud/aiplatform_v1/types/notebook_runtime.py @@ -65,7 +65,7 @@ class NotebookRuntimeTemplate(proto.Message): Attributes: name (str): - Output only. The resource name of the + The resource name of the NotebookRuntimeTemplate. display_name (str): Required. The display name of the diff --git a/google/cloud/aiplatform_v1/types/publisher_model.py b/google/cloud/aiplatform_v1/types/publisher_model.py index 157f783c83..8e92ed856b 100644 --- a/google/cloud/aiplatform_v1/types/publisher_model.py +++ b/google/cloud/aiplatform_v1/types/publisher_model.py @@ -423,6 +423,11 @@ class Deploy(proto.Message): Optional. The path to the directory containing the Model artifact and any of its supporting files. + deploy_task_name (str): + Optional. The name of the deploy task (e.g., + "text to image generation"). + + This field is a member of `oneof`_ ``_deploy_task_name``. title (str): Required. The title of the regional resource reference. 
@@ -466,6 +471,11 @@ class Deploy(proto.Message): proto.STRING, number=4, ) + deploy_task_name: str = proto.Field( + proto.STRING, + number=10, + optional=True, + ) title: str = proto.Field( proto.STRING, number=8, diff --git a/google/cloud/aiplatform_v1/types/tuning_job.py b/google/cloud/aiplatform_v1/types/tuning_job.py index f31c463c30..a386277c0a 100644 --- a/google/cloud/aiplatform_v1/types/tuning_job.py +++ b/google/cloud/aiplatform_v1/types/tuning_job.py @@ -46,7 +46,7 @@ class TuningJob(proto.Message): Attributes: base_model (str): - Model name for tuning, e.g., + The base model that is being tuned, e.g., "gemini-1.0-pro-002". This field is a member of `oneof`_ ``source_model``. @@ -397,11 +397,12 @@ class SupervisedHyperParameters(proto.Message): Attributes: epoch_count (int): - Optional. Number of training epoches for this - tuning job. + Optional. Number of complete passes the model + makes over the entire training dataset during + training. learning_rate_multiplier (float): - Optional. Learning rate multiplier for - tuning. + Optional. Multiplier for adjusting the + default learning rate. adapter_size (google.cloud.aiplatform_v1.types.SupervisedHyperParameters.AdapterSize): Optional. Adapter size for tuning. """ @@ -448,10 +449,12 @@ class SupervisedTuningSpec(proto.Message): Attributes: training_dataset_uri (str): Required. Cloud Storage path to file - containing training dataset for tuning. + containing training dataset for tuning. The + dataset must be formatted as a JSONL file. validation_dataset_uri (str): Optional. Cloud Storage path to file - containing validation dataset for tuning. + containing validation dataset for tuning. The + dataset must be formatted as a JSONL file. hyper_parameters (google.cloud.aiplatform_v1.types.SupervisedHyperParameters): Optional. Hyperparameters for SFT. 
""" diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index 784c2ff027..a465fc6168 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -119,6 +119,7 @@ from .types.content import Part from .types.content import SafetyRating from .types.content import SafetySetting +from .types.content import SearchEntryPoint from .types.content import Segment from .types.content import VideoMetadata from .types.content import HarmCategory @@ -170,6 +171,7 @@ from .types.dataset_service import SearchDataItemsRequest from .types.dataset_service import SearchDataItemsResponse from .types.dataset_service import UpdateDatasetRequest +from .types.dataset_service import UpdateDatasetVersionRequest from .types.dataset_version import DatasetVersion from .types.deployed_index_ref import DeployedIndexRef from .types.deployed_model_ref import DeployedModelRef @@ -729,6 +731,7 @@ from .types.nas_job import NasTrialDetail from .types.network_spec import NetworkSpec from .types.notebook_euc_config import NotebookEucConfig +from .types.notebook_execution_job import NotebookExecutionJob from .types.notebook_idle_shutdown_config import NotebookIdleShutdownConfig from .types.notebook_runtime import NotebookRuntime from .types.notebook_runtime import NotebookRuntimeTemplate @@ -736,12 +739,17 @@ from .types.notebook_runtime_template_ref import NotebookRuntimeTemplateRef from .types.notebook_service import AssignNotebookRuntimeOperationMetadata from .types.notebook_service import AssignNotebookRuntimeRequest +from .types.notebook_service import CreateNotebookExecutionJobRequest from .types.notebook_service import CreateNotebookRuntimeTemplateOperationMetadata from .types.notebook_service import CreateNotebookRuntimeTemplateRequest +from .types.notebook_service import DeleteNotebookExecutionJobRequest from .types.notebook_service import DeleteNotebookRuntimeRequest from .types.notebook_service 
import DeleteNotebookRuntimeTemplateRequest +from .types.notebook_service import GetNotebookExecutionJobRequest from .types.notebook_service import GetNotebookRuntimeRequest from .types.notebook_service import GetNotebookRuntimeTemplateRequest +from .types.notebook_service import ListNotebookExecutionJobsRequest +from .types.notebook_service import ListNotebookExecutionJobsResponse from .types.notebook_service import ListNotebookRuntimesRequest from .types.notebook_service import ListNotebookRuntimesResponse from .types.notebook_service import ListNotebookRuntimeTemplatesRequest @@ -752,6 +760,7 @@ from .types.notebook_service import UpgradeNotebookRuntimeOperationMetadata from .types.notebook_service import UpgradeNotebookRuntimeRequest from .types.notebook_service import UpgradeNotebookRuntimeResponse +from .types.notebook_service import NotebookExecutionJobView from .types.openapi import Schema from .types.openapi import Type from .types.operation import DeleteOperationMetadata @@ -797,7 +806,6 @@ from .types.pipeline_service import ListTrainingPipelinesRequest from .types.pipeline_service import ListTrainingPipelinesResponse from .types.pipeline_state import PipelineState -from .types.prediction_service import ChatCompletionsRequest from .types.prediction_service import CountTokensRequest from .types.prediction_service import CountTokensResponse from .types.prediction_service import DirectPredictRequest @@ -1084,7 +1092,6 @@ "CancelPipelineJobRequest", "CancelTrainingPipelineRequest", "Candidate", - "ChatCompletionsRequest", "CheckTrialEarlyStoppingStateMetatdata", "CheckTrialEarlyStoppingStateRequest", "CheckTrialEarlyStoppingStateResponse", @@ -1146,6 +1153,7 @@ "CreateModelMonitorRequest", "CreateModelMonitoringJobRequest", "CreateNasJobRequest", + "CreateNotebookExecutionJobRequest", "CreateNotebookRuntimeTemplateOperationMetadata", "CreateNotebookRuntimeTemplateRequest", "CreatePersistentResourceOperationMetadata", @@ -1209,6 +1217,7 @@ 
"DeleteModelRequest", "DeleteModelVersionRequest", "DeleteNasJobRequest", + "DeleteNotebookExecutionJobRequest", "DeleteNotebookRuntimeRequest", "DeleteNotebookRuntimeTemplateRequest", "DeleteOperationMetadata", @@ -1377,6 +1386,7 @@ "GetModelRequest", "GetNasJobRequest", "GetNasTrialDetailRequest", + "GetNotebookExecutionJobRequest", "GetNotebookRuntimeRequest", "GetNotebookRuntimeTemplateRequest", "GetPersistentResourceRequest", @@ -1502,6 +1512,8 @@ "ListNasJobsResponse", "ListNasTrialDetailsRequest", "ListNasTrialDetailsResponse", + "ListNotebookExecutionJobsRequest", + "ListNotebookExecutionJobsResponse", "ListNotebookRuntimeTemplatesRequest", "ListNotebookRuntimeTemplatesResponse", "ListNotebookRuntimesRequest", @@ -1606,6 +1618,8 @@ "NetworkSpec", "NfsMount", "NotebookEucConfig", + "NotebookExecutionJob", + "NotebookExecutionJobView", "NotebookIdleShutdownConfig", "NotebookRuntime", "NotebookRuntimeTemplate", @@ -1748,6 +1762,7 @@ "Schema", "SearchDataItemsRequest", "SearchDataItemsResponse", + "SearchEntryPoint", "SearchFeaturesRequest", "SearchFeaturesResponse", "SearchMigratableResourcesRequest", @@ -1857,6 +1872,7 @@ "UpdateArtifactRequest", "UpdateContextRequest", "UpdateDatasetRequest", + "UpdateDatasetVersionRequest", "UpdateDeploymentResourcePoolOperationMetadata", "UpdateEndpointRequest", "UpdateEntityTypeRequest", diff --git a/google/cloud/aiplatform_v1beta1/gapic_metadata.json b/google/cloud/aiplatform_v1beta1/gapic_metadata.json index 6ebc20fc5d..d99cd0cbdb 100644 --- a/google/cloud/aiplatform_v1beta1/gapic_metadata.json +++ b/google/cloud/aiplatform_v1beta1/gapic_metadata.json @@ -99,6 +99,11 @@ "methods": [ "update_dataset" ] + }, + "UpdateDatasetVersion": { + "methods": [ + "update_dataset_version" + ] } } }, @@ -194,6 +199,11 @@ "methods": [ "update_dataset" ] + }, + "UpdateDatasetVersion": { + "methods": [ + "update_dataset_version" + ] } } }, @@ -289,6 +299,11 @@ "methods": [ "update_dataset" ] + }, + "UpdateDatasetVersion": { + "methods": 
[ + "update_dataset_version" + ] } } } @@ -3524,6 +3539,11 @@ "create_notebook_runtime_template" ] }, + "DeleteNotebookExecutionJob": { + "methods": [ + "delete_notebook_execution_job" + ] + }, "DeleteNotebookRuntime": { "methods": [ "delete_notebook_runtime" @@ -3534,6 +3554,11 @@ "delete_notebook_runtime_template" ] }, + "GetNotebookExecutionJob": { + "methods": [ + "get_notebook_execution_job" + ] + }, "GetNotebookRuntime": { "methods": [ "get_notebook_runtime" @@ -3544,6 +3569,11 @@ "get_notebook_runtime_template" ] }, + "ListNotebookExecutionJobs": { + "methods": [ + "list_notebook_execution_jobs" + ] + }, "ListNotebookRuntimeTemplates": { "methods": [ "list_notebook_runtime_templates" @@ -3579,6 +3609,11 @@ "create_notebook_runtime_template" ] }, + "DeleteNotebookExecutionJob": { + "methods": [ + "delete_notebook_execution_job" + ] + }, "DeleteNotebookRuntime": { "methods": [ "delete_notebook_runtime" @@ -3589,6 +3624,11 @@ "delete_notebook_runtime_template" ] }, + "GetNotebookExecutionJob": { + "methods": [ + "get_notebook_execution_job" + ] + }, "GetNotebookRuntime": { "methods": [ "get_notebook_runtime" @@ -3599,6 +3639,11 @@ "get_notebook_runtime_template" ] }, + "ListNotebookExecutionJobs": { + "methods": [ + "list_notebook_execution_jobs" + ] + }, "ListNotebookRuntimeTemplates": { "methods": [ "list_notebook_runtime_templates" @@ -3634,6 +3679,11 @@ "create_notebook_runtime_template" ] }, + "DeleteNotebookExecutionJob": { + "methods": [ + "delete_notebook_execution_job" + ] + }, "DeleteNotebookRuntime": { "methods": [ "delete_notebook_runtime" @@ -3644,6 +3694,11 @@ "delete_notebook_runtime_template" ] }, + "GetNotebookExecutionJob": { + "methods": [ + "get_notebook_execution_job" + ] + }, "GetNotebookRuntime": { "methods": [ "get_notebook_runtime" @@ -3654,6 +3709,11 @@ "get_notebook_runtime_template" ] }, + "ListNotebookExecutionJobs": { + "methods": [ + "list_notebook_execution_jobs" + ] + }, "ListNotebookRuntimeTemplates": { "methods": [ 
"list_notebook_runtime_templates" @@ -3991,11 +4051,6 @@ "grpc": { "libraryClient": "PredictionServiceClient", "rpcs": { - "ChatCompletions": { - "methods": [ - "chat_completions" - ] - }, "CountTokens": { "methods": [ "count_tokens" @@ -4066,11 +4121,6 @@ "grpc-async": { "libraryClient": "PredictionServiceAsyncClient", "rpcs": { - "ChatCompletions": { - "methods": [ - "chat_completions" - ] - }, "CountTokens": { "methods": [ "count_tokens" @@ -4141,11 +4191,6 @@ "rest": { "libraryClient": "PredictionServiceClient", "rpcs": { - "ChatCompletions": { - "methods": [ - "chat_completions" - ] - }, "CountTokens": { "methods": [ "count_tokens" diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py index 42cf7e26c0..704f102748 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -229,7 +230,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, DatasetServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[str, DatasetServiceTransport, Callable[..., DatasetServiceTransport]] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -241,9 +244,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.DatasetServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. 
+ transport (Optional[Union[str,DatasetServiceTransport,Callable[..., DatasetServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the DatasetServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -371,8 +376,8 @@ async def sample_create_dataset(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset]) if request is not None and has_flattened_params: raise ValueError( @@ -380,7 +385,10 @@ async def sample_create_dataset(): "the individual field arguments should be set." ) - request = dataset_service.CreateDatasetRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.CreateDatasetRequest): + request = dataset_service.CreateDatasetRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -391,11 +399,9 @@ async def sample_create_dataset(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_dataset, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_dataset + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -486,8 +492,8 @@ async def sample_get_dataset(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -495,7 +501,10 @@ async def sample_get_dataset(): "the individual field arguments should be set." ) - request = dataset_service.GetDatasetRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.GetDatasetRequest): + request = dataset_service.GetDatasetRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -504,11 +513,9 @@ async def sample_get_dataset(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_dataset, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_dataset + ] # Certain fields should be provided within the metadata header; # add these here. @@ -610,8 +617,8 @@ async def sample_update_dataset(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([dataset, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -619,7 +626,10 @@ async def sample_update_dataset(): "the individual field arguments should be set." 
) - request = dataset_service.UpdateDatasetRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.UpdateDatasetRequest): + request = dataset_service.UpdateDatasetRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -630,11 +640,9 @@ async def sample_update_dataset(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_dataset, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_dataset + ] # Certain fields should be provided within the metadata header; # add these here. @@ -723,8 +731,8 @@ async def sample_list_datasets(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -732,7 +740,10 @@ async def sample_list_datasets(): "the individual field arguments should be set." ) - request = dataset_service.ListDatasetsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.ListDatasetsRequest): + request = dataset_service.ListDatasetsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -741,11 +752,9 @@ async def sample_list_datasets(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_datasets, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_datasets + ] # Certain fields should be provided within the metadata header; # add these here. @@ -852,8 +861,8 @@ async def sample_delete_dataset(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -861,7 +870,10 @@ async def sample_delete_dataset(): "the individual field arguments should be set." ) - request = dataset_service.DeleteDatasetRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.DeleteDatasetRequest): + request = dataset_service.DeleteDatasetRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -870,11 +882,9 @@ async def sample_delete_dataset(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_dataset, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_dataset + ] # Certain fields should be provided within the metadata header; # add these here. @@ -985,8 +995,8 @@ async def sample_import_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, import_configs]) if request is not None and has_flattened_params: raise ValueError( @@ -994,7 +1004,10 @@ async def sample_import_data(): "the individual field arguments should be set." ) - request = dataset_service.ImportDataRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.ImportDataRequest): + request = dataset_service.ImportDataRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1005,11 +1018,9 @@ async def sample_import_data(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.import_data, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.import_data + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1118,8 +1129,8 @@ async def sample_export_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, export_config]) if request is not None and has_flattened_params: raise ValueError( @@ -1127,7 +1138,10 @@ async def sample_export_data(): "the individual field arguments should be set." ) - request = dataset_service.ExportDataRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, dataset_service.ExportDataRequest): + request = dataset_service.ExportDataRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1138,11 +1152,9 @@ async def sample_export_data(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.export_data, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.export_data + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1257,8 +1269,8 @@ async def sample_create_dataset_version(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset_version]) if request is not None and has_flattened_params: raise ValueError( @@ -1266,7 +1278,10 @@ async def sample_create_dataset_version(): "the individual field arguments should be set." ) - request = dataset_service.CreateDatasetVersionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.CreateDatasetVersionRequest): + request = dataset_service.CreateDatasetVersionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1277,11 +1292,9 @@ async def sample_create_dataset_version(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_dataset_version, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_dataset_version + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1311,6 +1324,131 @@ async def sample_create_dataset_version(): # Done; return the response. return response + async def update_dataset_version( + self, + request: Optional[ + Union[dataset_service.UpdateDatasetVersionRequest, dict] + ] = None, + *, + dataset_version: Optional[gca_dataset_version.DatasetVersion] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_dataset_version.DatasetVersion: + r"""Updates a DatasetVersion. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1beta1 + + async def sample_update_dataset_version(): + # Create a client + client = aiplatform_v1beta1.DatasetServiceAsyncClient() + + # Initialize request argument(s) + dataset_version = aiplatform_v1beta1.DatasetVersion() + dataset_version.metadata.null_value = "NULL_VALUE" + + request = aiplatform_v1beta1.UpdateDatasetVersionRequest( + dataset_version=dataset_version, + ) + + # Make the request + response = await client.update_dataset_version(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.cloud.aiplatform_v1beta1.types.UpdateDatasetVersionRequest, dict]]): + The request object. Request message for + [DatasetService.UpdateDatasetVersion][google.cloud.aiplatform.v1beta1.DatasetService.UpdateDatasetVersion]. + dataset_version (:class:`google.cloud.aiplatform_v1beta1.types.DatasetVersion`): + Required. The DatasetVersion which + replaces the resource on the server. + + This corresponds to the ``dataset_version`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. The update mask applies to the resource. For + the ``FieldMask`` definition, see + [google.protobuf.FieldMask][google.protobuf.FieldMask]. + Updatable fields: + + - ``display_name`` + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ + Returns: + google.cloud.aiplatform_v1beta1.types.DatasetVersion: + Describes the dataset version. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([dataset_version, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.UpdateDatasetVersionRequest): + request = dataset_service.UpdateDatasetVersionRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if dataset_version is not None: + request.dataset_version = dataset_version + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_dataset_version + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("dataset_version.name", request.dataset_version.name),) + ), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + async def delete_dataset_version( self, request: Optional[ @@ -1389,8 +1527,8 @@ async def sample_delete_dataset_version(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1398,7 +1536,10 @@ async def sample_delete_dataset_version(): "the individual field arguments should be set." ) - request = dataset_service.DeleteDatasetVersionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.DeleteDatasetVersionRequest): + request = dataset_service.DeleteDatasetVersionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1407,11 +1548,9 @@ async def sample_delete_dataset_version(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_dataset_version, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_dataset_version + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1501,8 +1640,8 @@ async def sample_get_dataset_version(): Describes the dataset version. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1510,7 +1649,10 @@ async def sample_get_dataset_version(): "the individual field arguments should be set." ) - request = dataset_service.GetDatasetVersionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.GetDatasetVersionRequest): + request = dataset_service.GetDatasetVersionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1519,11 +1661,9 @@ async def sample_get_dataset_version(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_dataset_version, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_dataset_version + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1613,8 +1753,8 @@ async def sample_list_dataset_versions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1622,7 +1762,10 @@ async def sample_list_dataset_versions(): "the individual field arguments should be set." ) - request = dataset_service.ListDatasetVersionsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, dataset_service.ListDatasetVersionsRequest): + request = dataset_service.ListDatasetVersionsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1631,11 +1774,9 @@ async def sample_list_dataset_versions(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_dataset_versions, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_dataset_versions + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1737,8 +1878,8 @@ async def sample_restore_dataset_version(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1746,7 +1887,10 @@ async def sample_restore_dataset_version(): "the individual field arguments should be set." ) - request = dataset_service.RestoreDatasetVersionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.RestoreDatasetVersionRequest): + request = dataset_service.RestoreDatasetVersionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1755,11 +1899,9 @@ async def sample_restore_dataset_version(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.restore_dataset_version, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.restore_dataset_version + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1855,8 +1997,8 @@ async def sample_list_data_items(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1864,7 +2006,10 @@ async def sample_list_data_items(): "the individual field arguments should be set." ) - request = dataset_service.ListDataItemsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.ListDataItemsRequest): + request = dataset_service.ListDataItemsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1873,11 +2018,9 @@ async def sample_list_data_items(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_data_items, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_data_items + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1966,15 +2109,16 @@ async def sample_search_data_items(): """ # Create or coerce a protobuf request object. 
- request = dataset_service.SearchDataItemsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.SearchDataItemsRequest): + request = dataset_service.SearchDataItemsRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.search_data_items, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.search_data_items + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2071,8 +2215,8 @@ async def sample_list_saved_queries(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2080,7 +2224,10 @@ async def sample_list_saved_queries(): "the individual field arguments should be set." ) - request = dataset_service.ListSavedQueriesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.ListSavedQueriesRequest): + request = dataset_service.ListSavedQueriesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2089,11 +2236,9 @@ async def sample_list_saved_queries(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_saved_queries, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_saved_queries + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2200,8 +2345,8 @@ async def sample_delete_saved_query(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2209,7 +2354,10 @@ async def sample_delete_saved_query(): "the individual field arguments should be set." ) - request = dataset_service.DeleteSavedQueryRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.DeleteSavedQueryRequest): + request = dataset_service.DeleteSavedQueryRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2218,11 +2366,9 @@ async def sample_delete_saved_query(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_saved_query, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_saved_query + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2314,8 +2460,8 @@ async def sample_get_annotation_spec(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2323,7 +2469,10 @@ async def sample_get_annotation_spec(): "the individual field arguments should be set." ) - request = dataset_service.GetAnnotationSpecRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.GetAnnotationSpecRequest): + request = dataset_service.GetAnnotationSpecRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2332,11 +2481,9 @@ async def sample_get_annotation_spec(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_annotation_spec, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_annotation_spec + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2424,8 +2571,8 @@ async def sample_list_annotations(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2433,7 +2580,10 @@ async def sample_list_annotations(): "the individual field arguments should be set." 
) - request = dataset_service.ListAnnotationsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.ListAnnotationsRequest): + request = dataset_service.ListAnnotationsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2442,11 +2592,9 @@ async def sample_list_annotations(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_annotations, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_annotations + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py index 2f7d920c60..7c81d0810b 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -672,7 +673,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, DatasetServiceTransport]] = None, + transport: Optional[ + Union[str, DatasetServiceTransport, Callable[..., DatasetServiceTransport]] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -684,9 +687,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. 
- transport (Union[str, DatasetServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,DatasetServiceTransport,Callable[..., DatasetServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the DatasetServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -798,8 +803,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[DatasetServiceTransport], Callable[..., DatasetServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., DatasetServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -892,8 +904,8 @@ def sample_create_dataset(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset]) if request is not None and has_flattened_params: raise ValueError( @@ -901,10 +913,8 @@ def sample_create_dataset(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.CreateDatasetRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.CreateDatasetRequest): request = dataset_service.CreateDatasetRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1007,8 +1017,8 @@ def sample_get_dataset(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1016,10 +1026,8 @@ def sample_get_dataset(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.GetDatasetRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.GetDatasetRequest): request = dataset_service.GetDatasetRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1131,8 +1139,8 @@ def sample_update_dataset(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([dataset, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1140,10 +1148,8 @@ def sample_update_dataset(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.UpdateDatasetRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.UpdateDatasetRequest): request = dataset_service.UpdateDatasetRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1244,8 +1250,8 @@ def sample_list_datasets(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1253,10 +1259,8 @@ def sample_list_datasets(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.ListDatasetsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.ListDatasetsRequest): request = dataset_service.ListDatasetsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1373,8 +1377,8 @@ def sample_delete_dataset(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1382,10 +1386,8 @@ def sample_delete_dataset(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.DeleteDatasetRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.DeleteDatasetRequest): request = dataset_service.DeleteDatasetRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1506,8 +1508,8 @@ def sample_import_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, import_configs]) if request is not None and has_flattened_params: raise ValueError( @@ -1515,10 +1517,8 @@ def sample_import_data(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.ImportDataRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, dataset_service.ImportDataRequest): request = dataset_service.ImportDataRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1639,8 +1639,8 @@ def sample_export_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, export_config]) if request is not None and has_flattened_params: raise ValueError( @@ -1648,10 +1648,8 @@ def sample_export_data(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.ExportDataRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.ExportDataRequest): request = dataset_service.ExportDataRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1778,8 +1776,8 @@ def sample_create_dataset_version(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset_version]) if request is not None and has_flattened_params: raise ValueError( @@ -1787,10 +1785,8 @@ def sample_create_dataset_version(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.CreateDatasetVersionRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.CreateDatasetVersionRequest): request = dataset_service.CreateDatasetVersionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1832,6 +1828,128 @@ def sample_create_dataset_version(): # Done; return the response. return response + def update_dataset_version( + self, + request: Optional[ + Union[dataset_service.UpdateDatasetVersionRequest, dict] + ] = None, + *, + dataset_version: Optional[gca_dataset_version.DatasetVersion] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_dataset_version.DatasetVersion: + r"""Updates a DatasetVersion. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1beta1 + + def sample_update_dataset_version(): + # Create a client + client = aiplatform_v1beta1.DatasetServiceClient() + + # Initialize request argument(s) + dataset_version = aiplatform_v1beta1.DatasetVersion() + dataset_version.metadata.null_value = "NULL_VALUE" + + request = aiplatform_v1beta1.UpdateDatasetVersionRequest( + dataset_version=dataset_version, + ) + + # Make the request + response = client.update_dataset_version(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.cloud.aiplatform_v1beta1.types.UpdateDatasetVersionRequest, dict]): + The request object. Request message for + [DatasetService.UpdateDatasetVersion][google.cloud.aiplatform.v1beta1.DatasetService.UpdateDatasetVersion]. + dataset_version (google.cloud.aiplatform_v1beta1.types.DatasetVersion): + Required. The DatasetVersion which + replaces the resource on the server. + + This corresponds to the ``dataset_version`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The update mask applies to the resource. For + the ``FieldMask`` definition, see + [google.protobuf.FieldMask][google.protobuf.FieldMask]. + Updatable fields: + + - ``display_name`` + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.DatasetVersion: + Describes the dataset version. 
+ """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([dataset_version, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, dataset_service.UpdateDatasetVersionRequest): + request = dataset_service.UpdateDatasetVersionRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if dataset_version is not None: + request.dataset_version = dataset_version + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_dataset_version] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("dataset_version.name", request.dataset_version.name),) + ), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + def delete_dataset_version( self, request: Optional[ @@ -1910,8 +2028,8 @@ def sample_delete_dataset_version(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1919,10 +2037,8 @@ def sample_delete_dataset_version(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.DeleteDatasetVersionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.DeleteDatasetVersionRequest): request = dataset_service.DeleteDatasetVersionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2022,8 +2138,8 @@ def sample_get_dataset_version(): Describes the dataset version. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2031,10 +2147,8 @@ def sample_get_dataset_version(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.GetDatasetVersionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, dataset_service.GetDatasetVersionRequest): request = dataset_service.GetDatasetVersionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2134,8 +2248,8 @@ def sample_list_dataset_versions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2143,10 +2257,8 @@ def sample_list_dataset_versions(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.ListDatasetVersionsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.ListDatasetVersionsRequest): request = dataset_service.ListDatasetVersionsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2258,8 +2370,8 @@ def sample_restore_dataset_version(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2267,10 +2379,8 @@ def sample_restore_dataset_version(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.RestoreDatasetVersionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.RestoreDatasetVersionRequest): request = dataset_service.RestoreDatasetVersionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2376,8 +2486,8 @@ def sample_list_data_items(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2385,10 +2495,8 @@ def sample_list_data_items(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.ListDataItemsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.ListDataItemsRequest): request = dataset_service.ListDataItemsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2487,10 +2595,8 @@ def sample_search_data_items(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.SearchDataItemsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
+ # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.SearchDataItemsRequest): request = dataset_service.SearchDataItemsRequest(request) @@ -2593,8 +2699,8 @@ def sample_list_saved_queries(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2602,10 +2708,8 @@ def sample_list_saved_queries(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.ListSavedQueriesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.ListSavedQueriesRequest): request = dataset_service.ListSavedQueriesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2722,8 +2826,8 @@ def sample_delete_saved_query(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2731,10 +2835,8 @@ def sample_delete_saved_query(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.DeleteSavedQueryRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.DeleteSavedQueryRequest): request = dataset_service.DeleteSavedQueryRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2836,8 +2938,8 @@ def sample_get_annotation_spec(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2845,10 +2947,8 @@ def sample_get_annotation_spec(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.GetAnnotationSpecRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.GetAnnotationSpecRequest): request = dataset_service.GetAnnotationSpecRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2946,8 +3046,8 @@ def sample_list_annotations(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2955,10 +3055,8 @@ def sample_list_annotations(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.ListAnnotationsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, dataset_service.ListAnnotationsRequest): request = dataset_service.ListAnnotationsRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py index 40422091fd..628905029d 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py @@ -32,6 +32,7 @@ from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset from google.cloud.aiplatform_v1beta1.types import dataset_service from google.cloud.aiplatform_v1beta1.types import dataset_version +from google.cloud.aiplatform_v1beta1.types import dataset_version as gca_dataset_version from google.cloud.location import locations_pb2 # type: ignore from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import policy_pb2 # type: ignore @@ -176,6 +177,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + self.update_dataset_version: gapic_v1.method.wrap_method( + self.update_dataset_version, + default_timeout=None, + client_info=client_info, + ), 
self.delete_dataset_version: gapic_v1.method.wrap_method( self.delete_dataset_version, default_timeout=None, @@ -317,6 +323,18 @@ def create_dataset_version( ]: raise NotImplementedError() + @property + def update_dataset_version( + self, + ) -> Callable[ + [dataset_service.UpdateDatasetVersionRequest], + Union[ + gca_dataset_version.DatasetVersion, + Awaitable[gca_dataset_version.DatasetVersion], + ], + ]: + raise NotImplementedError() + @property def delete_dataset_version( self, diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py index 583c278344..18934cb685 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py @@ -30,6 +30,7 @@ from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset from google.cloud.aiplatform_v1beta1.types import dataset_service from google.cloud.aiplatform_v1beta1.types import dataset_version +from google.cloud.aiplatform_v1beta1.types import dataset_version as gca_dataset_version from google.cloud.location import locations_pb2 # type: ignore from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import policy_pb2 # type: ignore @@ -60,7 +61,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -80,14 +81,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. 
- This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -97,11 +101,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. 
client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -128,7 +132,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. @@ -169,7 +173,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -466,6 +472,35 @@ def create_dataset_version( ) return self._stubs["create_dataset_version"] + @property + def update_dataset_version( + self, + ) -> Callable[ + [dataset_service.UpdateDatasetVersionRequest], + gca_dataset_version.DatasetVersion, + ]: + r"""Return a callable for the update dataset version method over gRPC. + + Updates a DatasetVersion. + + Returns: + Callable[[~.UpdateDatasetVersionRequest], + ~.DatasetVersion]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "update_dataset_version" not in self._stubs: + self._stubs["update_dataset_version"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/UpdateDatasetVersion", + request_serializer=dataset_service.UpdateDatasetVersionRequest.serialize, + response_deserializer=gca_dataset_version.DatasetVersion.deserialize, + ) + return self._stubs["update_dataset_version"] + @property def delete_dataset_version( self, diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py index f9de620c74..7a3f1ec776 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -30,6 +32,7 @@ from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset from google.cloud.aiplatform_v1beta1.types import dataset_service from google.cloud.aiplatform_v1beta1.types import dataset_version +from google.cloud.aiplatform_v1beta1.types import dataset_version as gca_dataset_version from google.cloud.location import locations_pb2 # type: ignore from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import policy_pb2 # type: ignore @@ -75,7 +78,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. 
scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. @@ -105,7 +107,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -125,15 +127,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. 
If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -143,11 +148,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -174,7 +179,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. @@ -214,7 +219,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -483,6 +490,35 @@ def create_dataset_version( ) return self._stubs["create_dataset_version"] + @property + def update_dataset_version( + self, + ) -> Callable[ + [dataset_service.UpdateDatasetVersionRequest], + Awaitable[gca_dataset_version.DatasetVersion], + ]: + r"""Return a callable for the update dataset version method over gRPC. + + Updates a DatasetVersion. 
+ + Returns: + Callable[[~.UpdateDatasetVersionRequest], + Awaitable[~.DatasetVersion]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_dataset_version" not in self._stubs: + self._stubs["update_dataset_version"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/UpdateDatasetVersion", + request_serializer=dataset_service.UpdateDatasetVersionRequest.serialize, + response_deserializer=gca_dataset_version.DatasetVersion.deserialize, + ) + return self._stubs["update_dataset_version"] + @property def delete_dataset_version( self, @@ -772,6 +808,106 @@ def list_annotations( ) return self._stubs["list_annotations"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_dataset: gapic_v1.method_async.wrap_method( + self.create_dataset, + default_timeout=5.0, + client_info=client_info, + ), + self.get_dataset: gapic_v1.method_async.wrap_method( + self.get_dataset, + default_timeout=5.0, + client_info=client_info, + ), + self.update_dataset: gapic_v1.method_async.wrap_method( + self.update_dataset, + default_timeout=5.0, + client_info=client_info, + ), + self.list_datasets: gapic_v1.method_async.wrap_method( + self.list_datasets, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_dataset: gapic_v1.method_async.wrap_method( + self.delete_dataset, + default_timeout=5.0, + client_info=client_info, + ), + self.import_data: gapic_v1.method_async.wrap_method( + self.import_data, + default_timeout=5.0, + client_info=client_info, + ), + self.export_data: gapic_v1.method_async.wrap_method( + self.export_data, + default_timeout=5.0, + client_info=client_info, + ), + 
self.create_dataset_version: gapic_v1.method_async.wrap_method( + self.create_dataset_version, + default_timeout=None, + client_info=client_info, + ), + self.update_dataset_version: gapic_v1.method_async.wrap_method( + self.update_dataset_version, + default_timeout=None, + client_info=client_info, + ), + self.delete_dataset_version: gapic_v1.method_async.wrap_method( + self.delete_dataset_version, + default_timeout=None, + client_info=client_info, + ), + self.get_dataset_version: gapic_v1.method_async.wrap_method( + self.get_dataset_version, + default_timeout=None, + client_info=client_info, + ), + self.list_dataset_versions: gapic_v1.method_async.wrap_method( + self.list_dataset_versions, + default_timeout=None, + client_info=client_info, + ), + self.restore_dataset_version: gapic_v1.method_async.wrap_method( + self.restore_dataset_version, + default_timeout=None, + client_info=client_info, + ), + self.list_data_items: gapic_v1.method_async.wrap_method( + self.list_data_items, + default_timeout=5.0, + client_info=client_info, + ), + self.search_data_items: gapic_v1.method_async.wrap_method( + self.search_data_items, + default_timeout=None, + client_info=client_info, + ), + self.list_saved_queries: gapic_v1.method_async.wrap_method( + self.list_saved_queries, + default_timeout=None, + client_info=client_info, + ), + self.delete_saved_query: gapic_v1.method_async.wrap_method( + self.delete_saved_query, + default_timeout=None, + client_info=client_info, + ), + self.get_annotation_spec: gapic_v1.method_async.wrap_method( + self.get_annotation_spec, + default_timeout=5.0, + client_info=client_info, + ), + self.list_annotations: gapic_v1.method_async.wrap_method( + self.list_annotations, + default_timeout=5.0, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/rest.py index 
157eeb1488..391fd0ceb4 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/rest.py @@ -48,6 +48,7 @@ from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset from google.cloud.aiplatform_v1beta1.types import dataset_service from google.cloud.aiplatform_v1beta1.types import dataset_version +from google.cloud.aiplatform_v1beta1.types import dataset_version as gca_dataset_version from google.longrunning import operations_pb2 # type: ignore from .base import ( @@ -222,6 +223,14 @@ def post_update_dataset(self, response): logging.log(f"Received response: {response}") return response + def pre_update_dataset_version(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_update_dataset_version(self, response): + logging.log(f"Received response: {response}") + return response + transport = DatasetServiceRestTransport(interceptor=MyCustomDatasetServiceInterceptor()) client = DatasetServiceClient(transport=transport) @@ -638,6 +647,29 @@ def post_update_dataset(self, response: gca_dataset.Dataset) -> gca_dataset.Data """ return response + def pre_update_dataset_version( + self, + request: dataset_service.UpdateDatasetVersionRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[dataset_service.UpdateDatasetVersionRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for update_dataset_version + + Override in a subclass to manipulate the request or metadata + before they are sent to the DatasetService server. + """ + return request, metadata + + def post_update_dataset_version( + self, response: gca_dataset_version.DatasetVersion + ) -> gca_dataset_version.DatasetVersion: + """Post-rpc interceptor for update_dataset_version + + Override in a subclass to manipulate the response + after it is returned by the DatasetService server but before + it is returned to user code. 
+ """ + return response + def pre_get_location( self, request: locations_pb2.GetLocationRequest, @@ -1185,10 +1217,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1559,10 +1587,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -1921,10 +1945,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2299,10 +2319,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -2677,10 +2693,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -4435,6 +4447,101 @@ def __call__( resp = self._interceptor.post_update_dataset(resp) return resp + class _UpdateDatasetVersion(DatasetServiceRestStub): + def __hash__(self): + return hash("UpdateDatasetVersion") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + "updateMask": {}, + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: dataset_service.UpdateDatasetVersionRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_dataset_version.DatasetVersion: + r"""Call the update dataset version method over HTTP. + + Args: + request (~.dataset_service.UpdateDatasetVersionRequest): + The request object. Request message for + [DatasetService.UpdateDatasetVersion][google.cloud.aiplatform.v1beta1.DatasetService.UpdateDatasetVersion]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.gca_dataset_version.DatasetVersion: + Describes the dataset version. 
+ """ + + http_options: List[Dict[str, str]] = [ + { + "method": "patch", + "uri": "/v1beta1/{dataset_version.name=projects/*/locations/*/datasets/*/datasetVersions/*}", + "body": "dataset_version", + }, + ] + request, metadata = self._interceptor.pre_update_dataset_version( + request, metadata + ) + pb_request = dataset_service.UpdateDatasetVersionRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=False + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. 
+ if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = gca_dataset_version.DatasetVersion() + pb_resp = gca_dataset_version.DatasetVersion.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_update_dataset_version(resp) + return resp + @property def create_dataset( self, @@ -4605,6 +4712,17 @@ def update_dataset( # In C++ this would require a dynamic_cast return self._UpdateDataset(self._session, self._host, self._interceptor) # type: ignore + @property + def update_dataset_version( + self, + ) -> Callable[ + [dataset_service.UpdateDatasetVersionRequest], + gca_dataset_version.DatasetVersion, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._UpdateDatasetVersion(self._session, self._host, self._interceptor) # type: ignore + @property def get_location(self): return self._GetLocation(self._session, self._host, self._interceptor) # type: ignore @@ -5371,10 +5489,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -5802,10 +5916,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -6224,10 +6334,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - 
"uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -6663,10 +6769,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -7102,10 +7204,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/async_client.py index f8d43f0544..abd9a1e94a 100644 --- a/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -231,7 +232,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, DeploymentResourcePoolServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + DeploymentResourcePoolServiceTransport, + Callable[..., DeploymentResourcePoolServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -243,9 +250,11 @@ def __init__( credentials identify the application to the service; if 
none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.DeploymentResourcePoolServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,DeploymentResourcePoolServiceTransport,Callable[..., DeploymentResourcePoolServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the DeploymentResourcePoolServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -393,8 +402,8 @@ async def sample_create_deployment_resource_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, deployment_resource_pool, deployment_resource_pool_id] ) @@ -404,9 +413,17 @@ async def sample_create_deployment_resource_pool(): "the individual field arguments should be set." ) - request = deployment_resource_pool_service.CreateDeploymentResourcePoolRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, + deployment_resource_pool_service.CreateDeploymentResourcePoolRequest, + ): + request = ( + deployment_resource_pool_service.CreateDeploymentResourcePoolRequest( + request + ) + ) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -419,11 +436,9 @@ async def sample_create_deployment_resource_pool(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_deployment_resource_pool, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_deployment_resource_pool + ] # Certain fields should be provided within the metadata header; # add these here. @@ -521,8 +536,8 @@ async def sample_get_deployment_resource_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -530,9 +545,14 @@ async def sample_get_deployment_resource_pool(): "the individual field arguments should be set." ) - request = deployment_resource_pool_service.GetDeploymentResourcePoolRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, deployment_resource_pool_service.GetDeploymentResourcePoolRequest + ): + request = deployment_resource_pool_service.GetDeploymentResourcePoolRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -541,11 +561,9 @@ async def sample_get_deployment_resource_pool(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_deployment_resource_pool, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_deployment_resource_pool + ] # Certain fields should be provided within the metadata header; # add these here. @@ -638,8 +656,8 @@ async def sample_list_deployment_resource_pools(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -647,9 +665,16 @@ async def sample_list_deployment_resource_pools(): "the individual field arguments should be set." ) - request = deployment_resource_pool_service.ListDeploymentResourcePoolsRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, deployment_resource_pool_service.ListDeploymentResourcePoolsRequest + ): + request = ( + deployment_resource_pool_service.ListDeploymentResourcePoolsRequest( + request + ) + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -658,11 +683,9 @@ async def sample_list_deployment_resource_pools(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_deployment_resource_pools, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_deployment_resource_pools + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -774,8 +797,8 @@ async def sample_delete_deployment_resource_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -783,9 +806,17 @@ async def sample_delete_deployment_resource_pool(): "the individual field arguments should be set." ) - request = deployment_resource_pool_service.DeleteDeploymentResourcePoolRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, + deployment_resource_pool_service.DeleteDeploymentResourcePoolRequest, + ): + request = ( + deployment_resource_pool_service.DeleteDeploymentResourcePoolRequest( + request + ) + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -794,11 +825,9 @@ async def sample_delete_deployment_resource_pool(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_deployment_resource_pool, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_deployment_resource_pool + ] # Certain fields should be provided within the metadata header; # add these here. @@ -897,8 +926,8 @@ async def sample_query_deployed_models(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([deployment_resource_pool]) if request is not None and has_flattened_params: raise ValueError( @@ -906,7 +935,14 @@ async def sample_query_deployed_models(): "the individual field arguments should be set." ) - request = deployment_resource_pool_service.QueryDeployedModelsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, deployment_resource_pool_service.QueryDeployedModelsRequest + ): + request = deployment_resource_pool_service.QueryDeployedModelsRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -915,11 +951,9 @@ async def sample_query_deployed_models(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.query_deployed_models, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.query_deployed_models + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/client.py b/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/client.py index edb404bc6c..6e5cf4a4bf 100644 --- a/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -595,7 +596,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, DeploymentResourcePoolServiceTransport]] = None, + transport: Optional[ + Union[ + str, + DeploymentResourcePoolServiceTransport, + Callable[..., DeploymentResourcePoolServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -607,9 +614,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, DeploymentResourcePoolServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,DeploymentResourcePoolServiceTransport,Callable[..., DeploymentResourcePoolServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the DeploymentResourcePoolServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -727,8 +736,18 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[DeploymentResourcePoolServiceTransport], + Callable[..., DeploymentResourcePoolServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast( + Callable[..., DeploymentResourcePoolServiceTransport], transport + ) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -841,8 +860,8 @@ def sample_create_deployment_resource_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, deployment_resource_pool, deployment_resource_pool_id] ) @@ -852,10 +871,8 @@ def sample_create_deployment_resource_pool(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a deployment_resource_pool_service.CreateDeploymentResourcePoolRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, deployment_resource_pool_service.CreateDeploymentResourcePoolRequest, @@ -976,8 +993,8 @@ def sample_get_deployment_resource_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -985,10 +1002,8 @@ def sample_get_deployment_resource_pool(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a deployment_resource_pool_service.GetDeploymentResourcePoolRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, deployment_resource_pool_service.GetDeploymentResourcePoolRequest ): @@ -1097,8 +1112,8 @@ def sample_list_deployment_resource_pools(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1106,10 +1121,8 @@ def sample_list_deployment_resource_pools(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a deployment_resource_pool_service.ListDeploymentResourcePoolsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, deployment_resource_pool_service.ListDeploymentResourcePoolsRequest ): @@ -1239,8 +1252,8 @@ def sample_delete_deployment_resource_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1248,10 +1261,8 @@ def sample_delete_deployment_resource_pool(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a deployment_resource_pool_service.DeleteDeploymentResourcePoolRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, deployment_resource_pool_service.DeleteDeploymentResourcePoolRequest, @@ -1369,8 +1380,8 @@ def sample_query_deployed_models(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([deployment_resource_pool]) if request is not None and has_flattened_params: raise ValueError( @@ -1378,10 +1389,8 @@ def sample_query_deployed_models(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a deployment_resource_pool_service.QueryDeployedModelsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
+ # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, deployment_resource_pool_service.QueryDeployedModelsRequest ): diff --git a/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/transports/grpc.py index 944e48d01c..7c1684d92c 100644 --- a/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/transports/grpc.py @@ -58,7 +58,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -78,14 +78,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. 
+ channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -95,11 +98,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -126,7 +129,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -167,7 +170,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/transports/grpc_asyncio.py index 28baa17d1b..611402290f 100644 --- a/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -73,7 +75,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -103,7 +104,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -123,15 +124,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -141,11 +145,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -172,7 +176,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -212,7 +216,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -411,6 +417,36 @@ def query_deployed_models( ) return self._stubs["query_deployed_models"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_deployment_resource_pool: gapic_v1.method_async.wrap_method( + self.create_deployment_resource_pool, + default_timeout=None, + client_info=client_info, + ), + self.get_deployment_resource_pool: gapic_v1.method_async.wrap_method( + self.get_deployment_resource_pool, + default_timeout=None, + client_info=client_info, + ), + self.list_deployment_resource_pools: gapic_v1.method_async.wrap_method( + self.list_deployment_resource_pools, + default_timeout=None, + client_info=client_info, + ), + self.delete_deployment_resource_pool: gapic_v1.method_async.wrap_method( + self.delete_deployment_resource_pool, + default_timeout=None, + client_info=client_info, + ), + self.query_deployed_models: gapic_v1.method_async.wrap_method( + self.query_deployed_models, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/transports/rest.py index 13d673bd97..20c7d72b0a 100644 --- a/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/transports/rest.py @@ -802,10 +802,6 @@ def operations_client(self) -> 
operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1176,10 +1172,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -1538,10 +1530,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -1916,10 +1904,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -2294,10 +2278,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": 
"/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -3755,10 +3735,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -4186,10 +4162,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -4608,10 +4580,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -5047,10 +5015,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -5486,10 +5450,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py 
b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py index 4150d9e897..7af3b920b3 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -50,6 +51,7 @@ from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint from google.cloud.aiplatform_v1beta1.types import endpoint_service from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import service_networking from google.cloud.location import locations_pb2 # type: ignore from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import policy_pb2 # type: ignore @@ -222,7 +224,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, EndpointServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, EndpointServiceTransport, Callable[..., EndpointServiceTransport] + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -234,9 +240,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.EndpointServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,EndpointServiceTransport,Callable[..., EndpointServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the EndpointServiceTransport constructor. + If set to None, a transport is chosen automatically. 
NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -385,8 +393,8 @@ async def sample_create_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, endpoint, endpoint_id]) if request is not None and has_flattened_params: raise ValueError( @@ -394,7 +402,10 @@ async def sample_create_endpoint(): "the individual field arguments should be set." ) - request = endpoint_service.CreateEndpointRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, endpoint_service.CreateEndpointRequest): + request = endpoint_service.CreateEndpointRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -407,11 +418,9 @@ async def sample_create_endpoint(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_endpoint, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_endpoint + ] # Certain fields should be provided within the metadata header; # add these here. @@ -503,8 +512,8 @@ async def sample_get_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -512,7 +521,10 @@ async def sample_get_endpoint(): "the individual field arguments should be set." ) - request = endpoint_service.GetEndpointRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, endpoint_service.GetEndpointRequest): + request = endpoint_service.GetEndpointRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -521,11 +533,9 @@ async def sample_get_endpoint(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_endpoint, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_endpoint + ] # Certain fields should be provided within the metadata header; # add these here. @@ -613,8 +623,8 @@ async def sample_list_endpoints(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -622,7 +632,10 @@ async def sample_list_endpoints(): "the individual field arguments should be set." ) - request = endpoint_service.ListEndpointsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, endpoint_service.ListEndpointsRequest): + request = endpoint_service.ListEndpointsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -631,11 +644,9 @@ async def sample_list_endpoints(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_endpoints, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_endpoints + ] # Certain fields should be provided within the metadata header; # add these here. @@ -739,8 +750,8 @@ async def sample_update_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -748,7 +759,10 @@ async def sample_update_endpoint(): "the individual field arguments should be set." ) - request = endpoint_service.UpdateEndpointRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, endpoint_service.UpdateEndpointRequest): + request = endpoint_service.UpdateEndpointRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -759,11 +773,9 @@ async def sample_update_endpoint(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_endpoint, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_endpoint + ] # Certain fields should be provided within the metadata header; # add these here. @@ -863,8 +875,8 @@ async def sample_delete_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -872,7 +884,10 @@ async def sample_delete_endpoint(): "the individual field arguments should be set." ) - request = endpoint_service.DeleteEndpointRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, endpoint_service.DeleteEndpointRequest): + request = endpoint_service.DeleteEndpointRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -881,11 +896,9 @@ async def sample_delete_endpoint(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_endpoint, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_endpoint + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1022,8 +1035,8 @@ async def sample_deploy_model(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, traffic_split]) if request is not None and has_flattened_params: raise ValueError( @@ -1031,7 +1044,10 @@ async def sample_deploy_model(): "the individual field arguments should be set." ) - request = endpoint_service.DeployModelRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, endpoint_service.DeployModelRequest): + request = endpoint_service.DeployModelRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1045,11 +1061,9 @@ async def sample_deploy_model(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.deploy_model, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.deploy_model + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1173,8 +1187,8 @@ async def sample_undeploy_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) if request is not None and has_flattened_params: raise ValueError( @@ -1182,7 +1196,10 @@ async def sample_undeploy_model(): "the individual field arguments should be set." 
) - request = endpoint_service.UndeployModelRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, endpoint_service.UndeployModelRequest): + request = endpoint_service.UndeployModelRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1196,11 +1213,9 @@ async def sample_undeploy_model(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.undeploy_model, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.undeploy_model + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1336,8 +1351,8 @@ async def sample_mutate_deployed_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1345,7 +1360,10 @@ async def sample_mutate_deployed_model(): "the individual field arguments should be set." ) - request = endpoint_service.MutateDeployedModelRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, endpoint_service.MutateDeployedModelRequest): + request = endpoint_service.MutateDeployedModelRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -1358,11 +1376,9 @@ async def sample_mutate_deployed_model(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.mutate_deployed_model, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.mutate_deployed_model + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py index 961d917d8e..c1df1e76b0 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -55,6 +56,7 @@ from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint from google.cloud.aiplatform_v1beta1.types import endpoint_service from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import service_networking from google.cloud.location import locations_pb2 # type: ignore from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import policy_pb2 # type: ignore @@ -626,7 +628,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, EndpointServiceTransport]] = None, + transport: Optional[ + Union[ + str, EndpointServiceTransport, Callable[..., EndpointServiceTransport] + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -638,9 +644,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will 
attempt to ascertain the credentials from the environment. - transport (Union[str, EndpointServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,EndpointServiceTransport,Callable[..., EndpointServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the EndpointServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -752,8 +760,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[EndpointServiceTransport], Callable[..., EndpointServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., EndpointServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -867,8 +882,8 @@ def sample_create_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, endpoint, endpoint_id]) if request is not None and has_flattened_params: raise ValueError( @@ -876,10 +891,8 @@ def sample_create_endpoint(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.CreateEndpointRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, endpoint_service.CreateEndpointRequest): request = endpoint_service.CreateEndpointRequest(request) # If we have keyword arguments corresponding to fields on the @@ -985,8 +998,8 @@ def sample_get_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -994,10 +1007,8 @@ def sample_get_endpoint(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.GetEndpointRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, endpoint_service.GetEndpointRequest): request = endpoint_service.GetEndpointRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1095,8 +1106,8 @@ def sample_list_endpoints(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1104,10 +1115,8 @@ def sample_list_endpoints(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.ListEndpointsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, endpoint_service.ListEndpointsRequest): request = endpoint_service.ListEndpointsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1221,8 +1230,8 @@ def sample_update_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1230,10 +1239,8 @@ def sample_update_endpoint(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.UpdateEndpointRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, endpoint_service.UpdateEndpointRequest): request = endpoint_service.UpdateEndpointRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1345,8 +1352,8 @@ def sample_delete_endpoint(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1354,10 +1361,8 @@ def sample_delete_endpoint(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.DeleteEndpointRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, endpoint_service.DeleteEndpointRequest): request = endpoint_service.DeleteEndpointRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1504,8 +1509,8 @@ def sample_deploy_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, traffic_split]) if request is not None and has_flattened_params: raise ValueError( @@ -1513,10 +1518,8 @@ def sample_deploy_model(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.DeployModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, endpoint_service.DeployModelRequest): request = endpoint_service.DeployModelRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1654,8 +1657,8 @@ def sample_undeploy_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) if request is not None and has_flattened_params: raise ValueError( @@ -1663,10 +1666,8 @@ def sample_undeploy_model(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.UndeployModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, endpoint_service.UndeployModelRequest): request = endpoint_service.UndeployModelRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1816,8 +1817,8 @@ def sample_mutate_deployed_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1825,10 +1826,8 @@ def sample_mutate_deployed_model(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.MutateDeployedModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, endpoint_service.MutateDeployedModelRequest): request = endpoint_service.MutateDeployedModelRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py index 3382fba675..7febbd993d 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py @@ -57,7 +57,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -77,14 +77,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. 
scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -94,11 +97,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -125,7 +128,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -166,7 +169,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py index 79b551d9a6..33b088f277 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -72,7 +74,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -102,7 +103,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -122,15 +123,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -140,11 +144,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -171,7 +175,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -211,7 +215,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -486,6 +492,51 @@ def mutate_deployed_model( ) return self._stubs["mutate_deployed_model"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_endpoint: gapic_v1.method_async.wrap_method( + self.create_endpoint, + default_timeout=5.0, + client_info=client_info, + ), + self.get_endpoint: gapic_v1.method_async.wrap_method( + self.get_endpoint, + default_timeout=5.0, + client_info=client_info, + ), + self.list_endpoints: gapic_v1.method_async.wrap_method( + self.list_endpoints, + default_timeout=5.0, + client_info=client_info, + ), + self.update_endpoint: gapic_v1.method_async.wrap_method( + self.update_endpoint, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_endpoint: gapic_v1.method_async.wrap_method( + self.delete_endpoint, + default_timeout=5.0, + client_info=client_info, + ), + self.deploy_model: gapic_v1.method_async.wrap_method( + self.deploy_model, + default_timeout=5.0, + client_info=client_info, + ), + self.undeploy_model: gapic_v1.method_async.wrap_method( + self.undeploy_model, + default_timeout=5.0, + client_info=client_info, + ), + self.mutate_deployed_model: gapic_v1.method_async.wrap_method( + self.mutate_deployed_model, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/rest.py index 20dc716a71..cc9f2c2c34 100644 --- 
a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/rest.py @@ -874,10 +874,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1248,10 +1244,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -1610,10 +1602,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -1988,10 +1976,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -2366,10 +2350,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -4087,10 +4067,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -4518,10 +4494,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -4940,10 +4912,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -5379,10 +5347,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -5818,10 +5782,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { 
"method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/evaluation_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/evaluation_service/async_client.py index ef7a8abcc1..8b4f588e58 100644 --- a/google/cloud/aiplatform_v1beta1/services/evaluation_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/evaluation_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -194,7 +195,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, EvaluationServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + EvaluationServiceTransport, + Callable[..., EvaluationServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -206,9 +213,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.EvaluationServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,EvaluationServiceTransport,Callable[..., EvaluationServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the EvaluationServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -310,15 +319,16 @@ async def sample_evaluate_instances(): """ # Create or coerce a protobuf request object. 
- request = evaluation_service.EvaluateInstancesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, evaluation_service.EvaluateInstancesRequest): + request = evaluation_service.EvaluateInstancesRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.evaluate_instances, - default_timeout=60.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.evaluate_instances + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1beta1/services/evaluation_service/client.py b/google/cloud/aiplatform_v1beta1/services/evaluation_service/client.py index deeec19140..15b3835333 100644 --- a/google/cloud/aiplatform_v1beta1/services/evaluation_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/evaluation_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -509,7 +510,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, EvaluationServiceTransport]] = None, + transport: Optional[ + Union[ + str, + EvaluationServiceTransport, + Callable[..., EvaluationServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -521,9 +528,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, EvaluationServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. 
+ transport (Optional[Union[str,EvaluationServiceTransport,Callable[..., EvaluationServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the EvaluationServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -635,8 +644,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[EvaluationServiceTransport], + Callable[..., EvaluationServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., EvaluationServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -703,10 +720,8 @@ def sample_evaluate_instances(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a evaluation_service.EvaluateInstancesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, evaluation_service.EvaluateInstancesRequest): request = evaluation_service.EvaluateInstancesRequest(request) diff --git a/google/cloud/aiplatform_v1beta1/services/evaluation_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/evaluation_service/transports/grpc.py index 3095f9477d..9706d14174 100644 --- a/google/cloud/aiplatform_v1beta1/services/evaluation_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/evaluation_service/transports/grpc.py @@ -54,7 +54,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -74,14 +74,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. 
If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -91,11 +94,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -121,7 +124,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -162,7 +165,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/evaluation_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/evaluation_service/transports/grpc_asyncio.py index c3b57a7bfb..62446757f9 100644 --- a/google/cloud/aiplatform_v1beta1/services/evaluation_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/evaluation_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -69,7 +71,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -99,7 +100,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -119,15 +120,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -137,11 +141,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -167,7 +171,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -207,7 +211,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -265,6 +271,16 @@ def evaluate_instances( ) return self._stubs["evaluate_instances"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.evaluate_instances: gapic_v1.method_async.wrap_method( + self.evaluate_instances, + default_timeout=60.0, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/evaluation_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/evaluation_service/transports/rest.py index f52010d78a..fae67bfc91 100644 --- a/google/cloud/aiplatform_v1beta1/services/evaluation_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/evaluation_service/transports/rest.py @@ -1308,10 +1308,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1739,10 +1735,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -2161,10 +2153,6 @@ def __call__( "method": "get", "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2600,10 +2588,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -3039,10 +3023,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/extension_execution_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/extension_execution_service/async_client.py index c960dd6299..d429d3449b 100644 --- a/google/cloud/aiplatform_v1beta1/services/extension_execution_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/extension_execution_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -214,7 +215,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, ExtensionExecutionServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + ExtensionExecutionServiceTransport, + Callable[..., ExtensionExecutionServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -226,9 +233,11 @@ def 
__init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.ExtensionExecutionServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,ExtensionExecutionServiceTransport,Callable[..., ExtensionExecutionServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ExtensionExecutionServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -348,8 +357,8 @@ async def sample_execute_extension(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, operation_id]) if request is not None and has_flattened_params: raise ValueError( @@ -357,7 +366,10 @@ async def sample_execute_extension(): "the individual field arguments should be set." ) - request = extension_execution_service.ExecuteExtensionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, extension_execution_service.ExecuteExtensionRequest): + request = extension_execution_service.ExecuteExtensionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -368,11 +380,9 @@ async def sample_execute_extension(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.execute_extension, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.execute_extension + ] # Certain fields should be provided within the metadata header; # add these here. @@ -473,8 +483,8 @@ async def sample_query_extension(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, contents]) if request is not None and has_flattened_params: raise ValueError( @@ -482,7 +492,10 @@ async def sample_query_extension(): "the individual field arguments should be set." ) - request = extension_execution_service.QueryExtensionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, extension_execution_service.QueryExtensionRequest): + request = extension_execution_service.QueryExtensionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -493,11 +506,9 @@ async def sample_query_extension(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.query_extension, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.query_extension + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1beta1/services/extension_execution_service/client.py b/google/cloud/aiplatform_v1beta1/services/extension_execution_service/client.py index 8dbb33099e..e9b084532d 100644 --- a/google/cloud/aiplatform_v1beta1/services/extension_execution_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/extension_execution_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -556,7 +557,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, ExtensionExecutionServiceTransport]] = None, + transport: Optional[ + Union[ + str, + ExtensionExecutionServiceTransport, + Callable[..., ExtensionExecutionServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -568,9 +575,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ExtensionExecutionServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,ExtensionExecutionServiceTransport,Callable[..., ExtensionExecutionServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ExtensionExecutionServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -684,8 +693,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[ExtensionExecutionServiceTransport], + Callable[..., ExtensionExecutionServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., ExtensionExecutionServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -770,8 +787,8 @@ def sample_execute_extension(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, operation_id]) if request is not None and has_flattened_params: raise ValueError( @@ -779,10 +796,8 @@ def sample_execute_extension(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a extension_execution_service.ExecuteExtensionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, extension_execution_service.ExecuteExtensionRequest): request = extension_execution_service.ExecuteExtensionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -895,8 +910,8 @@ def sample_query_extension(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, contents]) if request is not None and has_flattened_params: raise ValueError( @@ -904,10 +919,8 @@ def sample_query_extension(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a extension_execution_service.QueryExtensionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, extension_execution_service.QueryExtensionRequest): request = extension_execution_service.QueryExtensionRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1beta1/services/extension_execution_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/extension_execution_service/transports/grpc.py index a23531338b..a6b53a247a 100644 --- a/google/cloud/aiplatform_v1beta1/services/extension_execution_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/extension_execution_service/transports/grpc.py @@ -54,7 +54,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -74,14 +74,17 @@ def __init__( credentials identify the application to the 
service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -91,11 +94,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. 
quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -121,7 +124,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. @@ -162,7 +165,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/extension_execution_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/extension_execution_service/transports/grpc_asyncio.py index 15762bdd79..f062063cea 100644 --- a/google/cloud/aiplatform_v1beta1/services/extension_execution_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/extension_execution_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -69,7 +71,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. 
These are only used when credentials are not specified and are passed to :func:`google.auth.default`. @@ -99,7 +100,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -119,15 +120,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. 
If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -137,11 +141,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -167,7 +171,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -207,7 +211,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -294,6 +300,21 @@ def query_extension( ) return self._stubs["query_extension"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.execute_extension: gapic_v1.method_async.wrap_method( + self.execute_extension, + default_timeout=None, + client_info=client_info, + ), + self.query_extension: gapic_v1.method_async.wrap_method( + self.query_extension, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/extension_execution_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/extension_execution_service/transports/rest.py index 74adc9a5b4..7656512694 100644 --- a/google/cloud/aiplatform_v1beta1/services/extension_execution_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/extension_execution_service/transports/rest.py @@ -1447,10 +1447,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1878,10 +1874,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": 
"delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -2300,10 +2292,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2739,10 +2727,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -3178,10 +3162,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/extension_registry_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/extension_registry_service/async_client.py index 979799ffd2..e80d7d480f 100644 --- a/google/cloud/aiplatform_v1beta1/services/extension_registry_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/extension_registry_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -223,7 +224,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, ExtensionRegistryServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + ExtensionRegistryServiceTransport, + Callable[..., ExtensionRegistryServiceTransport], + ] + ] = "grpc_asyncio", 
client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -235,9 +242,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.ExtensionRegistryServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,ExtensionRegistryServiceTransport,Callable[..., ExtensionRegistryServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ExtensionRegistryServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -370,8 +379,8 @@ async def sample_import_extension(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, extension]) if request is not None and has_flattened_params: raise ValueError( @@ -379,7 +388,10 @@ async def sample_import_extension(): "the individual field arguments should be set." ) - request = extension_registry_service.ImportExtensionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, extension_registry_service.ImportExtensionRequest): + request = extension_registry_service.ImportExtensionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -390,11 +402,9 @@ async def sample_import_extension(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.import_extension, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.import_extension + ] # Certain fields should be provided within the metadata header; # add these here. @@ -488,8 +498,8 @@ async def sample_get_extension(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -497,7 +507,10 @@ async def sample_get_extension(): "the individual field arguments should be set." ) - request = extension_registry_service.GetExtensionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, extension_registry_service.GetExtensionRequest): + request = extension_registry_service.GetExtensionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -506,11 +519,9 @@ async def sample_get_extension(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_extension, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_extension + ] # Certain fields should be provided within the metadata header; # add these here. @@ -600,8 +611,8 @@ async def sample_list_extensions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -609,7 +620,10 @@ async def sample_list_extensions(): "the individual field arguments should be set." ) - request = extension_registry_service.ListExtensionsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, extension_registry_service.ListExtensionsRequest): + request = extension_registry_service.ListExtensionsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -618,11 +632,9 @@ async def sample_list_extensions(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_extensions, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_extensions + ] # Certain fields should be provided within the metadata header; # add these here. @@ -740,8 +752,8 @@ async def sample_update_extension(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([extension, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -749,7 +761,10 @@ async def sample_update_extension(): "the individual field arguments should be set." ) - request = extension_registry_service.UpdateExtensionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, extension_registry_service.UpdateExtensionRequest): + request = extension_registry_service.UpdateExtensionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -760,11 +775,9 @@ async def sample_update_extension(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_extension, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_extension + ] # Certain fields should be provided within the metadata header; # add these here. @@ -866,8 +879,8 @@ async def sample_delete_extension(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -875,7 +888,10 @@ async def sample_delete_extension(): "the individual field arguments should be set." 
) - request = extension_registry_service.DeleteExtensionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, extension_registry_service.DeleteExtensionRequest): + request = extension_registry_service.DeleteExtensionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -884,11 +900,9 @@ async def sample_delete_extension(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_extension, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_extension + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1beta1/services/extension_registry_service/client.py b/google/cloud/aiplatform_v1beta1/services/extension_registry_service/client.py index acf469f12b..6f3998d70f 100644 --- a/google/cloud/aiplatform_v1beta1/services/extension_registry_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/extension_registry_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -589,7 +590,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, ExtensionRegistryServiceTransport]] = None, + transport: Optional[ + Union[ + str, + ExtensionRegistryServiceTransport, + Callable[..., ExtensionRegistryServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -601,9 +608,11 @@ def __init__( credentials identify the application to the service; if none are 
specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ExtensionRegistryServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,ExtensionRegistryServiceTransport,Callable[..., ExtensionRegistryServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ExtensionRegistryServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -717,8 +726,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[ExtensionRegistryServiceTransport], + Callable[..., ExtensionRegistryServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., ExtensionRegistryServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -816,8 +833,8 @@ def sample_import_extension(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, extension]) if request is not None and has_flattened_params: raise ValueError( @@ -825,10 +842,8 @@ def sample_import_extension(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a extension_registry_service.ImportExtensionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, extension_registry_service.ImportExtensionRequest): request = extension_registry_service.ImportExtensionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -934,8 +949,8 @@ def sample_get_extension(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -943,10 +958,8 @@ def sample_get_extension(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a extension_registry_service.GetExtensionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, extension_registry_service.GetExtensionRequest): request = extension_registry_service.GetExtensionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1046,8 +1059,8 @@ def sample_list_extensions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1055,10 +1068,8 @@ def sample_list_extensions(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a extension_registry_service.ListExtensionsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, extension_registry_service.ListExtensionsRequest): request = extension_registry_service.ListExtensionsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1186,8 +1197,8 @@ def sample_update_extension(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([extension, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1195,10 +1206,8 @@ def sample_update_extension(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a extension_registry_service.UpdateExtensionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, extension_registry_service.UpdateExtensionRequest): request = extension_registry_service.UpdateExtensionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1312,8 +1321,8 @@ def sample_delete_extension(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1321,10 +1330,8 @@ def sample_delete_extension(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a extension_registry_service.DeleteExtensionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, extension_registry_service.DeleteExtensionRequest): request = extension_registry_service.DeleteExtensionRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1beta1/services/extension_registry_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/extension_registry_service/transports/grpc.py index 6fe0ebaba4..c0a668ec1f 100644 --- a/google/cloud/aiplatform_v1beta1/services/extension_registry_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/extension_registry_service/transports/grpc.py @@ -57,7 +57,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -77,14 +77,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. 
+ channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -94,11 +97,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -125,7 +128,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -166,7 +169,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/extension_registry_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/extension_registry_service/transports/grpc_asyncio.py index 3e5afe111a..029555b6ce 100644 --- a/google/cloud/aiplatform_v1beta1/services/extension_registry_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/extension_registry_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -72,7 +74,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -102,7 +103,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -122,15 +123,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -140,11 +144,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -171,7 +175,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -211,7 +215,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -400,6 +406,36 @@ def delete_extension( ) return self._stubs["delete_extension"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.import_extension: gapic_v1.method_async.wrap_method( + self.import_extension, + default_timeout=None, + client_info=client_info, + ), + self.get_extension: gapic_v1.method_async.wrap_method( + self.get_extension, + default_timeout=None, + client_info=client_info, + ), + self.list_extensions: gapic_v1.method_async.wrap_method( + self.list_extensions, + default_timeout=None, + client_info=client_info, + ), + self.update_extension: gapic_v1.method_async.wrap_method( + self.update_extension, + default_timeout=None, + client_info=client_info, + ), + self.delete_extension: gapic_v1.method_async.wrap_method( + self.delete_extension, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/extension_registry_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/extension_registry_service/transports/rest.py index fff925fb2a..18a8c58180 100644 --- a/google/cloud/aiplatform_v1beta1/services/extension_registry_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/extension_registry_service/transports/rest.py @@ -791,10 +791,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1165,10 +1161,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -1527,10 +1519,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -1905,10 +1893,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -2283,10 +2267,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -3715,10 +3695,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - 
"method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -4146,10 +4122,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -4568,10 +4540,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -5007,10 +4975,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -5446,10 +5410,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/async_client.py index d2ba4ff9b8..807e843d96 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/async_client.py +++ 
b/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -244,7 +245,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, FeatureOnlineStoreAdminServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + FeatureOnlineStoreAdminServiceTransport, + Callable[..., FeatureOnlineStoreAdminServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -256,9 +263,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.FeatureOnlineStoreAdminServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,FeatureOnlineStoreAdminServiceTransport,Callable[..., FeatureOnlineStoreAdminServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FeatureOnlineStoreAdminServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -411,8 +420,8 @@ async def sample_create_feature_online_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any( [parent, feature_online_store, feature_online_store_id] ) @@ -422,9 +431,16 @@ async def sample_create_feature_online_store(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.CreateFeatureOnlineStoreRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.CreateFeatureOnlineStoreRequest + ): + request = ( + feature_online_store_admin_service.CreateFeatureOnlineStoreRequest( + request + ) + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -437,11 +453,9 @@ async def sample_create_feature_online_store(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_feature_online_store, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_feature_online_store + ] # Certain fields should be provided within the metadata header; # add these here. @@ -537,8 +551,8 @@ async def sample_get_feature_online_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -546,9 +560,14 @@ async def sample_get_feature_online_store(): "the individual field arguments should be set." 
) - request = feature_online_store_admin_service.GetFeatureOnlineStoreRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.GetFeatureOnlineStoreRequest + ): + request = feature_online_store_admin_service.GetFeatureOnlineStoreRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -557,11 +576,9 @@ async def sample_get_feature_online_store(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_feature_online_store, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_feature_online_store + ] # Certain fields should be provided within the metadata header; # add these here. @@ -654,8 +671,8 @@ async def sample_list_feature_online_stores(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -663,9 +680,14 @@ async def sample_list_feature_online_stores(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.ListFeatureOnlineStoresRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance( + request, feature_online_store_admin_service.ListFeatureOnlineStoresRequest + ): + request = feature_online_store_admin_service.ListFeatureOnlineStoresRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -674,11 +696,9 @@ async def sample_list_feature_online_stores(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_feature_online_stores, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_feature_online_stores + ] # Certain fields should be provided within the metadata header; # add these here. @@ -811,8 +831,8 @@ async def sample_update_feature_online_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature_online_store, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -820,9 +840,16 @@ async def sample_update_feature_online_store(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.UpdateFeatureOnlineStoreRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.UpdateFeatureOnlineStoreRequest + ): + request = ( + feature_online_store_admin_service.UpdateFeatureOnlineStoreRequest( + request + ) + ) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -833,11 +860,9 @@ async def sample_update_feature_online_store(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_feature_online_store, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_feature_online_store + ] # Certain fields should be provided within the metadata header; # add these here. @@ -961,8 +986,8 @@ async def sample_delete_feature_online_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: raise ValueError( @@ -970,9 +995,16 @@ async def sample_delete_feature_online_store(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.DeleteFeatureOnlineStoreRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.DeleteFeatureOnlineStoreRequest + ): + request = ( + feature_online_store_admin_service.DeleteFeatureOnlineStoreRequest( + request + ) + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -983,11 +1015,9 @@ async def sample_delete_feature_online_store(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_feature_online_store, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_feature_online_store + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1115,8 +1145,8 @@ async def sample_create_feature_view(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, feature_view, feature_view_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1124,7 +1154,14 @@ async def sample_create_feature_view(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.CreateFeatureViewRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.CreateFeatureViewRequest + ): + request = feature_online_store_admin_service.CreateFeatureViewRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1137,11 +1174,9 @@ async def sample_create_feature_view(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_feature_view, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_feature_view + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -1235,8 +1270,8 @@ async def sample_get_feature_view(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1244,7 +1279,12 @@ async def sample_get_feature_view(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.GetFeatureViewRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.GetFeatureViewRequest + ): + request = feature_online_store_admin_service.GetFeatureViewRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1253,11 +1293,9 @@ async def sample_get_feature_view(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_feature_view, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_feature_view + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1347,8 +1385,8 @@ async def sample_list_feature_views(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1356,7 +1394,14 @@ async def sample_list_feature_views(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.ListFeatureViewsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.ListFeatureViewsRequest + ): + request = feature_online_store_admin_service.ListFeatureViewsRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1365,11 +1410,9 @@ async def sample_list_feature_views(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_feature_views, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_feature_views + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1493,8 +1536,8 @@ async def sample_update_feature_view(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature_view, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1502,7 +1545,14 @@ async def sample_update_feature_view(): "the individual field arguments should be set." 
) - request = feature_online_store_admin_service.UpdateFeatureViewRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.UpdateFeatureViewRequest + ): + request = feature_online_store_admin_service.UpdateFeatureViewRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1513,11 +1563,9 @@ async def sample_update_feature_view(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_feature_view, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_feature_view + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1627,8 +1675,8 @@ async def sample_delete_feature_view(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1636,7 +1684,14 @@ async def sample_delete_feature_view(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.DeleteFeatureViewRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance( + request, feature_online_store_admin_service.DeleteFeatureViewRequest + ): + request = feature_online_store_admin_service.DeleteFeatureViewRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1645,11 +1700,9 @@ async def sample_delete_feature_view(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_feature_view, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_feature_view + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1742,8 +1795,8 @@ async def sample_sync_feature_view(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature_view]) if request is not None and has_flattened_params: raise ValueError( @@ -1751,7 +1804,12 @@ async def sample_sync_feature_view(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.SyncFeatureViewRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.SyncFeatureViewRequest + ): + request = feature_online_store_admin_service.SyncFeatureViewRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1760,11 +1818,9 @@ async def sample_sync_feature_view(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.sync_feature_view, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.sync_feature_view + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1854,8 +1910,8 @@ async def sample_get_feature_view_sync(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1863,7 +1919,14 @@ async def sample_get_feature_view_sync(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.GetFeatureViewSyncRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.GetFeatureViewSyncRequest + ): + request = feature_online_store_admin_service.GetFeatureViewSyncRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1872,11 +1935,9 @@ async def sample_get_feature_view_sync(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_feature_view_sync, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_feature_view_sync + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -1966,8 +2027,8 @@ async def sample_list_feature_view_syncs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1975,9 +2036,14 @@ async def sample_list_feature_view_syncs(): "the individual field arguments should be set." ) - request = feature_online_store_admin_service.ListFeatureViewSyncsRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_admin_service.ListFeatureViewSyncsRequest + ): + request = feature_online_store_admin_service.ListFeatureViewSyncsRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1986,11 +2052,9 @@ async def sample_list_feature_view_syncs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_feature_view_syncs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_feature_view_syncs + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/client.py b/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/client.py index eb6131407f..58bd66a35f 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -608,7 +609,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, FeatureOnlineStoreAdminServiceTransport]] = None, + transport: Optional[ + Union[ + str, + FeatureOnlineStoreAdminServiceTransport, + Callable[..., FeatureOnlineStoreAdminServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -620,9 +627,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, FeatureOnlineStoreAdminServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,FeatureOnlineStoreAdminServiceTransport,Callable[..., FeatureOnlineStoreAdminServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FeatureOnlineStoreAdminServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -740,8 +749,18 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[FeatureOnlineStoreAdminServiceTransport], + Callable[..., FeatureOnlineStoreAdminServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast( + Callable[..., FeatureOnlineStoreAdminServiceTransport], transport + ) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -859,8 +878,8 @@ def sample_create_feature_online_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, feature_online_store, feature_online_store_id] ) @@ -870,10 +889,8 @@ def sample_create_feature_online_store(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.CreateFeatureOnlineStoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.CreateFeatureOnlineStoreRequest ): @@ -991,8 +1008,8 @@ def sample_get_feature_online_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1000,10 +1017,8 @@ def sample_get_feature_online_store(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.GetFeatureOnlineStoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.GetFeatureOnlineStoreRequest ): @@ -1110,8 +1125,8 @@ def sample_list_feature_online_stores(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1119,10 +1134,8 @@ def sample_list_feature_online_stores(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.ListFeatureOnlineStoresRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.ListFeatureOnlineStoresRequest ): @@ -1271,8 +1284,8 @@ def sample_update_feature_online_store(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature_online_store, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1280,10 +1293,8 @@ def sample_update_feature_online_store(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.UpdateFeatureOnlineStoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.UpdateFeatureOnlineStoreRequest ): @@ -1427,8 +1438,8 @@ def sample_delete_feature_online_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: raise ValueError( @@ -1436,10 +1447,8 @@ def sample_delete_feature_online_store(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.DeleteFeatureOnlineStoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, feature_online_store_admin_service.DeleteFeatureOnlineStoreRequest ): @@ -1587,8 +1596,8 @@ def sample_create_feature_view(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, feature_view, feature_view_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1596,10 +1605,8 @@ def sample_create_feature_view(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.CreateFeatureViewRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.CreateFeatureViewRequest ): @@ -1711,8 +1718,8 @@ def sample_get_feature_view(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1720,10 +1727,8 @@ def sample_get_feature_view(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.GetFeatureViewRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
+ # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.GetFeatureViewRequest ): @@ -1825,8 +1830,8 @@ def sample_list_feature_views(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1834,10 +1839,8 @@ def sample_list_feature_views(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.ListFeatureViewsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.ListFeatureViewsRequest ): @@ -1975,8 +1978,8 @@ def sample_update_feature_view(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature_view, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1984,10 +1987,8 @@ def sample_update_feature_view(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.UpdateFeatureViewRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.UpdateFeatureViewRequest ): @@ -2113,8 +2114,8 @@ def sample_delete_feature_view(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2122,10 +2123,8 @@ def sample_delete_feature_view(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.DeleteFeatureViewRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.DeleteFeatureViewRequest ): @@ -2232,8 +2231,8 @@ def sample_sync_feature_view(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature_view]) if request is not None and has_flattened_params: raise ValueError( @@ -2241,10 +2240,8 @@ def sample_sync_feature_view(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.SyncFeatureViewRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.SyncFeatureViewRequest ): @@ -2346,8 +2343,8 @@ def sample_get_feature_view_sync(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2355,10 +2352,8 @@ def sample_get_feature_view_sync(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.GetFeatureViewSyncRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.GetFeatureViewSyncRequest ): @@ -2462,8 +2457,8 @@ def sample_list_feature_view_syncs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2471,10 +2466,8 @@ def sample_list_feature_view_syncs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_admin_service.ListFeatureViewSyncsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_admin_service.ListFeatureViewSyncsRequest ): diff --git a/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/transports/grpc.py index 285f624d30..10ae9380f8 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/transports/grpc.py @@ -61,7 +61,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -81,14 +81,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. 
credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -98,11 +101,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. 
client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -129,7 +132,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. @@ -170,7 +173,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/transports/grpc_asyncio.py index c9a2ccf168..6ac8e900d1 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -76,7 +78,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. 
These are only used when credentials are not specified and are passed to :func:`google.auth.default`. @@ -106,7 +107,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -126,15 +127,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. 
If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -144,11 +148,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -175,7 +179,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -215,7 +219,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -642,6 +648,76 @@ def list_feature_view_syncs( ) return self._stubs["list_feature_view_syncs"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_feature_online_store: gapic_v1.method_async.wrap_method( + self.create_feature_online_store, + default_timeout=5.0, + client_info=client_info, + ), + self.get_feature_online_store: gapic_v1.method_async.wrap_method( + self.get_feature_online_store, + default_timeout=5.0, + client_info=client_info, + ), + self.list_feature_online_stores: gapic_v1.method_async.wrap_method( + self.list_feature_online_stores, + default_timeout=5.0, + client_info=client_info, + ), + self.update_feature_online_store: gapic_v1.method_async.wrap_method( + self.update_feature_online_store, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_feature_online_store: gapic_v1.method_async.wrap_method( + self.delete_feature_online_store, + default_timeout=5.0, + client_info=client_info, + ), + self.create_feature_view: gapic_v1.method_async.wrap_method( + self.create_feature_view, + default_timeout=5.0, + client_info=client_info, + ), + self.get_feature_view: gapic_v1.method_async.wrap_method( + self.get_feature_view, + default_timeout=5.0, + client_info=client_info, + ), + self.list_feature_views: gapic_v1.method_async.wrap_method( + self.list_feature_views, + default_timeout=5.0, + client_info=client_info, + ), + self.update_feature_view: gapic_v1.method_async.wrap_method( + self.update_feature_view, + default_timeout=5.0, + client_info=client_info, + ), 
+ self.delete_feature_view: gapic_v1.method_async.wrap_method( + self.delete_feature_view, + default_timeout=5.0, + client_info=client_info, + ), + self.sync_feature_view: gapic_v1.method_async.wrap_method( + self.sync_feature_view, + default_timeout=None, + client_info=client_info, + ), + self.get_feature_view_sync: gapic_v1.method_async.wrap_method( + self.get_feature_view_sync, + default_timeout=None, + client_info=client_info, + ), + self.list_feature_view_syncs: gapic_v1.method_async.wrap_method( + self.list_feature_view_syncs, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/transports/rest.py index d35ee19662..6e1738fd11 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/transports/rest.py @@ -1077,10 +1077,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1451,10 +1447,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -1813,10 +1805,6 @@ def operations_client(self) 
-> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2191,10 +2179,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -2569,10 +2553,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -4876,10 +4856,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -5307,10 +5283,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -5729,10 +5701,6 @@ def __call__( "method": "get", "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -6168,10 +6136,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -6607,10 +6571,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/feature_online_store_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/feature_online_store_service/async_client.py index 48a9793073..e65876ad28 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_online_store_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/feature_online_store_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -212,7 +213,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, FeatureOnlineStoreServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + FeatureOnlineStoreServiceTransport, + Callable[..., FeatureOnlineStoreServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -224,9 +231,11 @@ def 
__init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.FeatureOnlineStoreServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,FeatureOnlineStoreServiceTransport,Callable[..., FeatureOnlineStoreServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FeatureOnlineStoreServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -347,8 +356,8 @@ async def sample_fetch_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature_view, data_key]) if request is not None and has_flattened_params: raise ValueError( @@ -356,7 +365,12 @@ async def sample_fetch_feature_values(): "the individual field arguments should be set." ) - request = feature_online_store_service.FetchFeatureValuesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_service.FetchFeatureValuesRequest + ): + request = feature_online_store_service.FetchFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -367,11 +381,9 @@ async def sample_fetch_feature_values(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.fetch_feature_values, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.fetch_feature_values + ] # Certain fields should be provided within the metadata header; # add these here. @@ -472,11 +484,9 @@ def request_generator(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.streaming_fetch_feature_values, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.streaming_fetch_feature_values + ] # Certain fields should be provided within the metadata header; # add these here. @@ -558,15 +568,18 @@ async def sample_search_nearest_entities(): """ # Create or coerce a protobuf request object. - request = feature_online_store_service.SearchNearestEntitiesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, feature_online_store_service.SearchNearestEntitiesRequest + ): + request = feature_online_store_service.SearchNearestEntitiesRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.search_nearest_entities, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.search_nearest_entities + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1beta1/services/feature_online_store_service/client.py b/google/cloud/aiplatform_v1beta1/services/feature_online_store_service/client.py index dd38f608cb..28d8332855 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_online_store_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/feature_online_store_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -539,7 +540,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, FeatureOnlineStoreServiceTransport]] = None, + transport: Optional[ + Union[ + str, + FeatureOnlineStoreServiceTransport, + Callable[..., FeatureOnlineStoreServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -551,9 +558,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, FeatureOnlineStoreServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,FeatureOnlineStoreServiceTransport,Callable[..., FeatureOnlineStoreServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FeatureOnlineStoreServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -667,8 +676,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[FeatureOnlineStoreServiceTransport], + Callable[..., FeatureOnlineStoreServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., FeatureOnlineStoreServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -754,8 +771,8 @@ def sample_fetch_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature_view, data_key]) if request is not None and has_flattened_params: raise ValueError( @@ -763,10 +780,8 @@ def sample_fetch_feature_values(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_service.FetchFeatureValuesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_service.FetchFeatureValuesRequest ): @@ -961,10 +976,8 @@ def sample_search_nearest_entities(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a feature_online_store_service.SearchNearestEntitiesRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, feature_online_store_service.SearchNearestEntitiesRequest ): diff --git a/google/cloud/aiplatform_v1beta1/services/feature_online_store_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/feature_online_store_service/transports/grpc.py index da2cc258bf..14ed444649 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_online_store_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/feature_online_store_service/transports/grpc.py @@ -54,7 +54,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -74,14 +74,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. 
+ ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -91,11 +94,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -121,7 +124,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -162,7 +165,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/feature_online_store_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/feature_online_store_service/transports/grpc_asyncio.py index c682a409de..d4d1177eaa 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_online_store_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/feature_online_store_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -69,7 +71,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -99,7 +100,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -119,15 +120,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -137,11 +141,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -167,7 +171,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -207,7 +211,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -331,6 +337,26 @@ def search_nearest_entities( ) return self._stubs["search_nearest_entities"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.fetch_feature_values: gapic_v1.method_async.wrap_method( + self.fetch_feature_values, + default_timeout=None, + client_info=client_info, + ), + self.streaming_fetch_feature_values: gapic_v1.method_async.wrap_method( + self.streaming_fetch_feature_values, + default_timeout=None, + client_info=client_info, + ), + self.search_nearest_entities: gapic_v1.method_async.wrap_method( + self.search_nearest_entities, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/feature_online_store_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/feature_online_store_service/transports/rest.py index 109f8a8156..79bc0c5d38 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_online_store_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/feature_online_store_service/transports/rest.py @@ -1486,10 +1486,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1917,10 +1913,6 @@ def __call__( "method": "delete", 
"uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -2339,10 +2331,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2778,10 +2766,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -3217,10 +3201,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/feature_registry_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/feature_registry_service/async_client.py index a87dd3d61d..9990545bea 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_registry_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/feature_registry_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -219,7 +220,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, 
FeatureRegistryServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + FeatureRegistryServiceTransport, + Callable[..., FeatureRegistryServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -231,9 +238,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.FeatureRegistryServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,FeatureRegistryServiceTransport,Callable[..., FeatureRegistryServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FeatureRegistryServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -338,7 +347,7 @@ async def sample_create_feature_group(): parent (:class:`str`): Required. The resource name of the Location to create FeatureGroups. Format: - ``projects/{project}/locations/{location}'`` + ``projects/{project}/locations/{location}`` This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this @@ -379,8 +388,8 @@ async def sample_create_feature_group(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, feature_group, feature_group_id]) if request is not None and has_flattened_params: raise ValueError( @@ -388,7 +397,10 @@ async def sample_create_feature_group(): "the individual field arguments should be set." ) - request = feature_registry_service.CreateFeatureGroupRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, feature_registry_service.CreateFeatureGroupRequest): + request = feature_registry_service.CreateFeatureGroupRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -401,11 +413,9 @@ async def sample_create_feature_group(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_feature_group, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_feature_group + ] # Certain fields should be provided within the metadata header; # add these here. @@ -496,8 +506,8 @@ async def sample_get_feature_group(): Vertex AI Feature Group. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -505,7 +515,10 @@ async def sample_get_feature_group(): "the individual field arguments should be set." ) - request = feature_registry_service.GetFeatureGroupRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, feature_registry_service.GetFeatureGroupRequest): + request = feature_registry_service.GetFeatureGroupRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -514,11 +527,9 @@ async def sample_get_feature_group(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_feature_group, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_feature_group + ] # Certain fields should be provided within the metadata header; # add these here. @@ -608,8 +619,8 @@ async def sample_list_feature_groups(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -617,7 +628,10 @@ async def sample_list_feature_groups(): "the individual field arguments should be set." ) - request = feature_registry_service.ListFeatureGroupsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, feature_registry_service.ListFeatureGroupsRequest): + request = feature_registry_service.ListFeatureGroupsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -626,11 +640,9 @@ async def sample_list_feature_groups(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_feature_groups, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_feature_groups + ] # Certain fields should be provided within the metadata header; # add these here. @@ -753,8 +765,8 @@ async def sample_update_feature_group(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature_group, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -762,7 +774,10 @@ async def sample_update_feature_group(): "the individual field arguments should be set." ) - request = feature_registry_service.UpdateFeatureGroupRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, feature_registry_service.UpdateFeatureGroupRequest): + request = feature_registry_service.UpdateFeatureGroupRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -773,11 +788,9 @@ async def sample_update_feature_group(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_feature_group, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_feature_group + ] # Certain fields should be provided within the metadata header; # add these here. @@ -897,8 +910,8 @@ async def sample_delete_feature_group(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: raise ValueError( @@ -906,7 +919,10 @@ async def sample_delete_feature_group(): "the individual field arguments should be set." ) - request = feature_registry_service.DeleteFeatureGroupRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, feature_registry_service.DeleteFeatureGroupRequest): + request = feature_registry_service.DeleteFeatureGroupRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -917,11 +933,9 @@ async def sample_delete_feature_group(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_feature_group, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_feature_group + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1050,8 +1064,8 @@ async def sample_create_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, feature, feature_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1059,7 +1073,10 @@ async def sample_create_feature(): "the individual field arguments should be set." 
) - request = featurestore_service.CreateFeatureRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.CreateFeatureRequest): + request = featurestore_service.CreateFeatureRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1072,11 +1089,9 @@ async def sample_create_feature(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_feature, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_feature + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1173,8 +1188,8 @@ async def sample_get_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1182,7 +1197,10 @@ async def sample_get_feature(): "the individual field arguments should be set." ) - request = featurestore_service.GetFeatureRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.GetFeatureRequest): + request = featurestore_service.GetFeatureRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -1191,11 +1209,9 @@ async def sample_get_feature(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_feature, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_feature + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1289,8 +1305,8 @@ async def sample_list_features(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1298,7 +1314,10 @@ async def sample_list_features(): "the individual field arguments should be set." ) - request = featurestore_service.ListFeaturesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.ListFeaturesRequest): + request = featurestore_service.ListFeaturesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1307,11 +1326,9 @@ async def sample_list_features(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_features, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_features + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -1436,8 +1453,8 @@ async def sample_update_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1445,7 +1462,10 @@ async def sample_update_feature(): "the individual field arguments should be set." ) - request = featurestore_service.UpdateFeatureRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.UpdateFeatureRequest): + request = featurestore_service.UpdateFeatureRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1456,11 +1476,9 @@ async def sample_update_feature(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_feature, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_feature + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1573,8 +1591,8 @@ async def sample_delete_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1582,7 +1600,10 @@ async def sample_delete_feature(): "the individual field arguments should be set." ) - request = featurestore_service.DeleteFeatureRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.DeleteFeatureRequest): + request = featurestore_service.DeleteFeatureRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1591,11 +1612,9 @@ async def sample_delete_feature(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_feature, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_feature + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1beta1/services/feature_registry_service/client.py b/google/cloud/aiplatform_v1beta1/services/feature_registry_service/client.py index 6adeba848c..9c8db18f18 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_registry_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/feature_registry_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -575,7 +576,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, FeatureRegistryServiceTransport]] = None, + transport: Optional[ + Union[ + str, + FeatureRegistryServiceTransport, + Callable[..., FeatureRegistryServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -587,9 +594,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, FeatureRegistryServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,FeatureRegistryServiceTransport,Callable[..., FeatureRegistryServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FeatureRegistryServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -701,8 +710,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[FeatureRegistryServiceTransport], + Callable[..., FeatureRegistryServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., FeatureRegistryServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -772,7 +789,7 @@ def sample_create_feature_group(): parent (str): Required. The resource name of the Location to create FeatureGroups. Format: - ``projects/{project}/locations/{location}'`` + ``projects/{project}/locations/{location}`` This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this @@ -813,8 +830,8 @@ def sample_create_feature_group(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, feature_group, feature_group_id]) if request is not None and has_flattened_params: raise ValueError( @@ -822,10 +839,8 @@ def sample_create_feature_group(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_registry_service.CreateFeatureGroupRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, feature_registry_service.CreateFeatureGroupRequest): request = feature_registry_service.CreateFeatureGroupRequest(request) # If we have keyword arguments corresponding to fields on the @@ -930,8 +945,8 @@ def sample_get_feature_group(): Vertex AI Feature Group. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -939,10 +954,8 @@ def sample_get_feature_group(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_registry_service.GetFeatureGroupRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, feature_registry_service.GetFeatureGroupRequest): request = feature_registry_service.GetFeatureGroupRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1042,8 +1055,8 @@ def sample_list_feature_groups(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1051,10 +1064,8 @@ def sample_list_feature_groups(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a feature_registry_service.ListFeatureGroupsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, feature_registry_service.ListFeatureGroupsRequest): request = feature_registry_service.ListFeatureGroupsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1187,8 +1198,8 @@ def sample_update_feature_group(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature_group, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1196,10 +1207,8 @@ def sample_update_feature_group(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_registry_service.UpdateFeatureGroupRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, feature_registry_service.UpdateFeatureGroupRequest): request = feature_registry_service.UpdateFeatureGroupRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1331,8 +1340,8 @@ def sample_delete_feature_group(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: raise ValueError( @@ -1340,10 +1349,8 @@ def sample_delete_feature_group(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a feature_registry_service.DeleteFeatureGroupRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, feature_registry_service.DeleteFeatureGroupRequest): request = feature_registry_service.DeleteFeatureGroupRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1484,8 +1491,8 @@ def sample_create_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, feature, feature_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1493,10 +1500,8 @@ def sample_create_feature(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.CreateFeatureRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, featurestore_service.CreateFeatureRequest): request = featurestore_service.CreateFeatureRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1607,8 +1612,8 @@ def sample_get_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1616,10 +1621,8 @@ def sample_get_feature(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.GetFeatureRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.GetFeatureRequest): request = featurestore_service.GetFeatureRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1723,8 +1726,8 @@ def sample_list_features(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1732,10 +1735,8 @@ def sample_list_features(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.ListFeaturesRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.ListFeaturesRequest): request = featurestore_service.ListFeaturesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1870,8 +1871,8 @@ def sample_update_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1879,10 +1880,8 @@ def sample_update_feature(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.UpdateFeatureRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.UpdateFeatureRequest): request = featurestore_service.UpdateFeatureRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2007,8 +2006,8 @@ def sample_delete_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2016,10 +2015,8 @@ def sample_delete_feature(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.DeleteFeatureRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.DeleteFeatureRequest): request = featurestore_service.DeleteFeatureRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1beta1/services/feature_registry_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/feature_registry_service/transports/grpc.py index c786952c82..c5ba30658a 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_registry_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/feature_registry_service/transports/grpc.py @@ -59,7 +59,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -79,14 +79,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. 
credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -96,11 +99,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. 
client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -127,7 +130,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. @@ -168,7 +171,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/feature_registry_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/feature_registry_service/transports/grpc_asyncio.py index 8905e76d4e..56f18c1a2b 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_registry_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/feature_registry_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -74,7 +76,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -104,7 +105,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -124,15 +125,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -142,11 +146,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -173,7 +177,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -213,7 +217,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -543,6 +549,61 @@ def delete_feature( ) return self._stubs["delete_feature"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_feature_group: gapic_v1.method_async.wrap_method( + self.create_feature_group, + default_timeout=None, + client_info=client_info, + ), + self.get_feature_group: gapic_v1.method_async.wrap_method( + self.get_feature_group, + default_timeout=None, + client_info=client_info, + ), + self.list_feature_groups: gapic_v1.method_async.wrap_method( + self.list_feature_groups, + default_timeout=None, + client_info=client_info, + ), + self.update_feature_group: gapic_v1.method_async.wrap_method( + self.update_feature_group, + default_timeout=None, + client_info=client_info, + ), + self.delete_feature_group: gapic_v1.method_async.wrap_method( + self.delete_feature_group, + default_timeout=None, + client_info=client_info, + ), + self.create_feature: gapic_v1.method_async.wrap_method( + self.create_feature, + default_timeout=None, + client_info=client_info, + ), + self.get_feature: gapic_v1.method_async.wrap_method( + self.get_feature, + default_timeout=None, + client_info=client_info, + ), + self.list_features: gapic_v1.method_async.wrap_method( + self.list_features, + default_timeout=None, + client_info=client_info, + ), + self.update_feature: gapic_v1.method_async.wrap_method( + self.update_feature, + default_timeout=None, + client_info=client_info, + ), + self.delete_feature: gapic_v1.method_async.wrap_method( + self.delete_feature, + default_timeout=None, + 
client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/feature_registry_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/feature_registry_service/transports/rest.py index be73a557c9..e8435c45f2 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_registry_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/feature_registry_service/transports/rest.py @@ -948,10 +948,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1322,10 +1318,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -1684,10 +1676,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2062,10 +2050,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -2440,10 +2424,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -4376,10 +4356,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -4807,10 +4783,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -5229,10 +5201,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -5668,10 +5636,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": 
"/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -6107,10 +6071,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/async_client.py index 95d4163982..28fa218975 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -216,8 +217,12 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[ - str, FeaturestoreOnlineServingServiceTransport + transport: Optional[ + Union[ + str, + FeaturestoreOnlineServingServiceTransport, + Callable[..., FeaturestoreOnlineServingServiceTransport], + ] ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, @@ -230,9 +235,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.FeaturestoreOnlineServingServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,FeaturestoreOnlineServingServiceTransport,Callable[..., FeaturestoreOnlineServingServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. 
+ If a Callable is given, it will be called with the same set of initialization + arguments as used in the FeaturestoreOnlineServingServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -354,8 +361,8 @@ async def sample_read_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: raise ValueError( @@ -363,7 +370,12 @@ async def sample_read_feature_values(): "the individual field arguments should be set." ) - request = featurestore_online_service.ReadFeatureValuesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, featurestore_online_service.ReadFeatureValuesRequest + ): + request = featurestore_online_service.ReadFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -372,11 +384,9 @@ async def sample_read_feature_values(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.read_feature_values, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.read_feature_values + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -477,8 +487,8 @@ async def sample_streaming_read_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: raise ValueError( @@ -486,7 +496,14 @@ async def sample_streaming_read_feature_values(): "the individual field arguments should be set." ) - request = featurestore_online_service.StreamingReadFeatureValuesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, featurestore_online_service.StreamingReadFeatureValuesRequest + ): + request = featurestore_online_service.StreamingReadFeatureValuesRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -495,11 +512,9 @@ async def sample_streaming_read_feature_values(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.streaming_read_feature_values, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.streaming_read_feature_values + ] # Certain fields should be provided within the metadata header; # add these here. @@ -608,8 +623,8 @@ async def sample_write_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([entity_type, payloads]) if request is not None and has_flattened_params: raise ValueError( @@ -617,7 +632,12 @@ async def sample_write_feature_values(): "the individual field arguments should be set." ) - request = featurestore_online_service.WriteFeatureValuesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, featurestore_online_service.WriteFeatureValuesRequest + ): + request = featurestore_online_service.WriteFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -628,11 +648,9 @@ async def sample_write_feature_values(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.write_feature_values, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.write_feature_values + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/client.py index 43bcb6ff54..875e30a4c0 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -544,7 +545,11 @@ def __init__( *, credentials: Optional[ga_credentials.Credentials] = None, transport: Optional[ - Union[str, FeaturestoreOnlineServingServiceTransport] + Union[ + str, + FeaturestoreOnlineServingServiceTransport, + Callable[..., FeaturestoreOnlineServingServiceTransport], + ] ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, @@ -557,9 +562,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, FeaturestoreOnlineServingServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,FeaturestoreOnlineServingServiceTransport,Callable[..., FeaturestoreOnlineServingServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FeaturestoreOnlineServingServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -677,8 +684,18 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[FeaturestoreOnlineServingServiceTransport], + Callable[..., FeaturestoreOnlineServingServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast( + Callable[..., FeaturestoreOnlineServingServiceTransport], transport + ) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -765,8 +782,8 @@ def sample_read_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: raise ValueError( @@ -774,10 +791,8 @@ def sample_read_feature_values(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_online_service.ReadFeatureValuesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, featurestore_online_service.ReadFeatureValuesRequest ): @@ -888,8 +903,8 @@ def sample_streaming_read_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: raise ValueError( @@ -897,10 +912,8 @@ def sample_streaming_read_feature_values(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_online_service.StreamingReadFeatureValuesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, featurestore_online_service.StreamingReadFeatureValuesRequest ): @@ -1025,8 +1038,8 @@ def sample_write_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type, payloads]) if request is not None and has_flattened_params: raise ValueError( @@ -1034,10 +1047,8 @@ def sample_write_feature_values(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_online_service.WriteFeatureValuesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, featurestore_online_service.WriteFeatureValuesRequest ): diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc.py index 72745711b1..aa437239a9 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc.py @@ -56,7 +56,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -76,14 +76,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. 
If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -93,11 +96,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -123,7 +126,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -164,7 +167,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc_asyncio.py index fdb95defa1..125a1f3129 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -71,7 +73,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -101,7 +102,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -121,15 +122,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -139,11 +143,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -169,7 +173,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -209,7 +213,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -336,6 +342,26 @@ def write_feature_values( ) return self._stubs["write_feature_values"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.read_feature_values: gapic_v1.method_async.wrap_method( + self.read_feature_values, + default_timeout=5.0, + client_info=client_info, + ), + self.streaming_read_feature_values: gapic_v1.method_async.wrap_method( + self.streaming_read_feature_values, + default_timeout=5.0, + client_info=client_info, + ), + self.write_feature_values: gapic_v1.method_async.wrap_method( + self.write_feature_values, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/rest.py index f4e1f8aa7a..5fb4c891d0 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/rest.py @@ -1601,10 +1601,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -2032,10 +2028,6 @@ def __call__( "method": 
"delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -2454,10 +2446,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2893,10 +2881,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -3332,10 +3316,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py index 7190186fa1..4bec282b54 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -223,7 +224,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, 
FeaturestoreServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + FeaturestoreServiceTransport, + Callable[..., FeaturestoreServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -235,9 +242,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.FeaturestoreServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,FeaturestoreServiceTransport,Callable[..., FeaturestoreServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FeaturestoreServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -380,8 +389,8 @@ async def sample_create_featurestore(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, featurestore, featurestore_id]) if request is not None and has_flattened_params: raise ValueError( @@ -389,7 +398,10 @@ async def sample_create_featurestore(): "the individual field arguments should be set." 
) - request = featurestore_service.CreateFeaturestoreRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.CreateFeaturestoreRequest): + request = featurestore_service.CreateFeaturestoreRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -402,11 +414,9 @@ async def sample_create_featurestore(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_featurestore, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_featurestore + ] # Certain fields should be provided within the metadata header; # add these here. @@ -502,8 +512,8 @@ async def sample_get_featurestore(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -511,7 +521,10 @@ async def sample_get_featurestore(): "the individual field arguments should be set." ) - request = featurestore_service.GetFeaturestoreRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.GetFeaturestoreRequest): + request = featurestore_service.GetFeaturestoreRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -520,11 +533,9 @@ async def sample_get_featurestore(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_featurestore, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_featurestore + ] # Certain fields should be provided within the metadata header; # add these here. @@ -614,8 +625,8 @@ async def sample_list_featurestores(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -623,7 +634,10 @@ async def sample_list_featurestores(): "the individual field arguments should be set." ) - request = featurestore_service.ListFeaturestoresRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.ListFeaturestoresRequest): + request = featurestore_service.ListFeaturestoresRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -632,11 +646,9 @@ async def sample_list_featurestores(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_featurestores, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_featurestores + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -759,8 +771,8 @@ async def sample_update_featurestore(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([featurestore, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -768,7 +780,10 @@ async def sample_update_featurestore(): "the individual field arguments should be set." ) - request = featurestore_service.UpdateFeaturestoreRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.UpdateFeaturestoreRequest): + request = featurestore_service.UpdateFeaturestoreRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -779,11 +794,9 @@ async def sample_update_featurestore(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_featurestore, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_featurestore + ] # Certain fields should be provided within the metadata header; # add these here. @@ -906,8 +919,8 @@ async def sample_delete_featurestore(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name, force]) if request is not None and has_flattened_params: raise ValueError( @@ -915,7 +928,10 @@ async def sample_delete_featurestore(): "the individual field arguments should be set." ) - request = featurestore_service.DeleteFeaturestoreRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.DeleteFeaturestoreRequest): + request = featurestore_service.DeleteFeaturestoreRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -926,11 +942,9 @@ async def sample_delete_featurestore(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_featurestore, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_featurestore + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1054,8 +1068,8 @@ async def sample_create_entity_type(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, entity_type, entity_type_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1063,7 +1077,10 @@ async def sample_create_entity_type(): "the individual field arguments should be set." ) - request = featurestore_service.CreateEntityTypeRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, featurestore_service.CreateEntityTypeRequest): + request = featurestore_service.CreateEntityTypeRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1076,11 +1093,9 @@ async def sample_create_entity_type(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_entity_type, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_entity_type + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1177,8 +1192,8 @@ async def sample_get_entity_type(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1186,7 +1201,10 @@ async def sample_get_entity_type(): "the individual field arguments should be set." ) - request = featurestore_service.GetEntityTypeRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.GetEntityTypeRequest): + request = featurestore_service.GetEntityTypeRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1195,11 +1213,9 @@ async def sample_get_entity_type(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_entity_type, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_entity_type + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1289,8 +1305,8 @@ async def sample_list_entity_types(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1298,7 +1314,10 @@ async def sample_list_entity_types(): "the individual field arguments should be set." ) - request = featurestore_service.ListEntityTypesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.ListEntityTypesRequest): + request = featurestore_service.ListEntityTypesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1307,11 +1326,9 @@ async def sample_list_entity_types(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_entity_types, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_entity_types + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1436,8 +1453,8 @@ async def sample_update_entity_type(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1445,7 +1462,10 @@ async def sample_update_entity_type(): "the individual field arguments should be set." ) - request = featurestore_service.UpdateEntityTypeRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.UpdateEntityTypeRequest): + request = featurestore_service.UpdateEntityTypeRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1456,11 +1476,9 @@ async def sample_update_entity_type(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_entity_type, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_entity_type + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1574,8 +1592,8 @@ async def sample_delete_entity_type(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: raise ValueError( @@ -1583,7 +1601,10 @@ async def sample_delete_entity_type(): "the individual field arguments should be set." 
) - request = featurestore_service.DeleteEntityTypeRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.DeleteEntityTypeRequest): + request = featurestore_service.DeleteEntityTypeRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1594,11 +1615,9 @@ async def sample_delete_entity_type(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_entity_type, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_entity_type + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1727,8 +1746,8 @@ async def sample_create_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, feature, feature_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1736,7 +1755,10 @@ async def sample_create_feature(): "the individual field arguments should be set." ) - request = featurestore_service.CreateFeatureRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.CreateFeatureRequest): + request = featurestore_service.CreateFeatureRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -1749,11 +1771,9 @@ async def sample_create_feature(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_feature, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_feature + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1872,8 +1892,8 @@ async def sample_batch_create_features(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: raise ValueError( @@ -1881,7 +1901,10 @@ async def sample_batch_create_features(): "the individual field arguments should be set." ) - request = featurestore_service.BatchCreateFeaturesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.BatchCreateFeaturesRequest): + request = featurestore_service.BatchCreateFeaturesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1892,11 +1915,9 @@ async def sample_batch_create_features(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.batch_create_features, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_create_features + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1993,8 +2014,8 @@ async def sample_get_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2002,7 +2023,10 @@ async def sample_get_feature(): "the individual field arguments should be set." ) - request = featurestore_service.GetFeatureRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.GetFeatureRequest): + request = featurestore_service.GetFeatureRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2011,11 +2035,9 @@ async def sample_get_feature(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_feature, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_feature + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2109,8 +2131,8 @@ async def sample_list_features(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2118,7 +2140,10 @@ async def sample_list_features(): "the individual field arguments should be set." ) - request = featurestore_service.ListFeaturesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.ListFeaturesRequest): + request = featurestore_service.ListFeaturesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2127,11 +2152,9 @@ async def sample_list_features(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_features, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_features + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2250,8 +2273,8 @@ async def sample_update_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -2259,7 +2282,10 @@ async def sample_update_feature(): "the individual field arguments should be set." 
) - request = featurestore_service.UpdateFeatureRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.UpdateFeatureRequest): + request = featurestore_service.UpdateFeatureRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2270,11 +2296,9 @@ async def sample_update_feature(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_feature, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_feature + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2379,8 +2403,8 @@ async def sample_delete_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2388,7 +2412,10 @@ async def sample_delete_feature(): "the individual field arguments should be set." ) - request = featurestore_service.DeleteFeatureRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.DeleteFeatureRequest): + request = featurestore_service.DeleteFeatureRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -2397,11 +2424,9 @@ async def sample_delete_feature(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_feature, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_feature + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2535,8 +2560,8 @@ async def sample_import_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: raise ValueError( @@ -2544,7 +2569,10 @@ async def sample_import_feature_values(): "the individual field arguments should be set." ) - request = featurestore_service.ImportFeatureValuesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.ImportFeatureValuesRequest): + request = featurestore_service.ImportFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2553,11 +2581,9 @@ async def sample_import_feature_values(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.import_feature_values, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.import_feature_values + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2678,8 +2704,8 @@ async def sample_batch_read_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([featurestore]) if request is not None and has_flattened_params: raise ValueError( @@ -2687,7 +2713,10 @@ async def sample_batch_read_feature_values(): "the individual field arguments should be set." ) - request = featurestore_service.BatchReadFeatureValuesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.BatchReadFeatureValuesRequest): + request = featurestore_service.BatchReadFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2696,11 +2725,9 @@ async def sample_batch_read_feature_values(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.batch_read_feature_values, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_read_feature_values + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -2811,8 +2838,8 @@ async def sample_export_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: raise ValueError( @@ -2820,7 +2847,10 @@ async def sample_export_feature_values(): "the individual field arguments should be set." ) - request = featurestore_service.ExportFeatureValuesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.ExportFeatureValuesRequest): + request = featurestore_service.ExportFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2829,11 +2859,9 @@ async def sample_export_feature_values(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.export_feature_values, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.export_feature_values + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2952,8 +2980,8 @@ async def sample_delete_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: raise ValueError( @@ -2961,7 +2989,10 @@ async def sample_delete_feature_values(): "the individual field arguments should be set." ) - request = featurestore_service.DeleteFeatureValuesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, featurestore_service.DeleteFeatureValuesRequest): + request = featurestore_service.DeleteFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2970,11 +3001,9 @@ async def sample_delete_feature_values(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_feature_values, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_feature_values + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3151,8 +3180,8 @@ async def sample_search_features(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([location, query]) if request is not None and has_flattened_params: raise ValueError( @@ -3160,7 +3189,10 @@ async def sample_search_features(): "the individual field arguments should be set." ) - request = featurestore_service.SearchFeaturesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, featurestore_service.SearchFeaturesRequest): + request = featurestore_service.SearchFeaturesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3171,11 +3203,9 @@ async def sample_search_features(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.search_features, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.search_features + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py index 2f86021285..c0d3d706cc 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -599,7 +600,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, FeaturestoreServiceTransport]] = None, + transport: Optional[ + Union[ + str, + FeaturestoreServiceTransport, + Callable[..., FeaturestoreServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -611,9 +618,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, FeaturestoreServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. 
+ transport (Optional[Union[str,FeaturestoreServiceTransport,Callable[..., FeaturestoreServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FeaturestoreServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -725,8 +734,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[FeaturestoreServiceTransport], + Callable[..., FeaturestoreServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., FeaturestoreServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -834,8 +851,8 @@ def sample_create_featurestore(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, featurestore, featurestore_id]) if request is not None and has_flattened_params: raise ValueError( @@ -843,10 +860,8 @@ def sample_create_featurestore(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.CreateFeaturestoreRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.CreateFeaturestoreRequest): request = featurestore_service.CreateFeaturestoreRequest(request) # If we have keyword arguments corresponding to fields on the @@ -956,8 +971,8 @@ def sample_get_featurestore(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -965,10 +980,8 @@ def sample_get_featurestore(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.GetFeaturestoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.GetFeaturestoreRequest): request = featurestore_service.GetFeaturestoreRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1068,8 +1081,8 @@ def sample_list_featurestores(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1077,10 +1090,8 @@ def sample_list_featurestores(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.ListFeaturestoresRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.ListFeaturestoresRequest): request = featurestore_service.ListFeaturestoresRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1213,8 +1224,8 @@ def sample_update_featurestore(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([featurestore, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1222,10 +1233,8 @@ def sample_update_featurestore(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.UpdateFeaturestoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, featurestore_service.UpdateFeaturestoreRequest): request = featurestore_service.UpdateFeaturestoreRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1360,8 +1369,8 @@ def sample_delete_featurestore(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: raise ValueError( @@ -1369,10 +1378,8 @@ def sample_delete_featurestore(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.DeleteFeaturestoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.DeleteFeaturestoreRequest): request = featurestore_service.DeleteFeaturestoreRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1508,8 +1515,8 @@ def sample_create_entity_type(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, entity_type, entity_type_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1517,10 +1524,8 @@ def sample_create_entity_type(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.CreateEntityTypeRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.CreateEntityTypeRequest): request = featurestore_service.CreateEntityTypeRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1631,8 +1636,8 @@ def sample_get_entity_type(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1640,10 +1645,8 @@ def sample_get_entity_type(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.GetEntityTypeRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.GetEntityTypeRequest): request = featurestore_service.GetEntityTypeRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1743,8 +1746,8 @@ def sample_list_entity_types(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1752,10 +1755,8 @@ def sample_list_entity_types(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.ListEntityTypesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.ListEntityTypesRequest): request = featurestore_service.ListEntityTypesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1890,8 +1891,8 @@ def sample_update_entity_type(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1899,10 +1900,8 @@ def sample_update_entity_type(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.UpdateEntityTypeRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, featurestore_service.UpdateEntityTypeRequest): request = featurestore_service.UpdateEntityTypeRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2028,8 +2027,8 @@ def sample_delete_entity_type(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, force]) if request is not None and has_flattened_params: raise ValueError( @@ -2037,10 +2036,8 @@ def sample_delete_entity_type(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.DeleteEntityTypeRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.DeleteEntityTypeRequest): request = featurestore_service.DeleteEntityTypeRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2181,8 +2178,8 @@ def sample_create_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, feature, feature_id]) if request is not None and has_flattened_params: raise ValueError( @@ -2190,10 +2187,8 @@ def sample_create_feature(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.CreateFeatureRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.CreateFeatureRequest): request = featurestore_service.CreateFeatureRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2326,8 +2321,8 @@ def sample_batch_create_features(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: raise ValueError( @@ -2335,10 +2330,8 @@ def sample_batch_create_features(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.BatchCreateFeaturesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.BatchCreateFeaturesRequest): request = featurestore_service.BatchCreateFeaturesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2447,8 +2440,8 @@ def sample_get_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2456,10 +2449,8 @@ def sample_get_feature(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.GetFeatureRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.GetFeatureRequest): request = featurestore_service.GetFeatureRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2563,8 +2554,8 @@ def sample_list_features(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2572,10 +2563,8 @@ def sample_list_features(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.ListFeaturesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, featurestore_service.ListFeaturesRequest): request = featurestore_service.ListFeaturesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2704,8 +2693,8 @@ def sample_update_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([feature, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -2713,10 +2702,8 @@ def sample_update_feature(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.UpdateFeatureRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.UpdateFeatureRequest): request = featurestore_service.UpdateFeatureRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2833,8 +2820,8 @@ def sample_delete_feature(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2842,10 +2829,8 @@ def sample_delete_feature(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.DeleteFeatureRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.DeleteFeatureRequest): request = featurestore_service.DeleteFeatureRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2989,8 +2974,8 @@ def sample_import_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: raise ValueError( @@ -2998,10 +2983,8 @@ def sample_import_feature_values(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.ImportFeatureValuesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.ImportFeatureValuesRequest): request = featurestore_service.ImportFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3132,8 +3115,8 @@ def sample_batch_read_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([featurestore]) if request is not None and has_flattened_params: raise ValueError( @@ -3141,10 +3124,8 @@ def sample_batch_read_feature_values(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.BatchReadFeatureValuesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.BatchReadFeatureValuesRequest): request = featurestore_service.BatchReadFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3267,8 +3248,8 @@ def sample_export_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: raise ValueError( @@ -3276,10 +3257,8 @@ def sample_export_feature_values(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.ExportFeatureValuesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, featurestore_service.ExportFeatureValuesRequest): request = featurestore_service.ExportFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3408,8 +3387,8 @@ def sample_delete_feature_values(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([entity_type]) if request is not None and has_flattened_params: raise ValueError( @@ -3417,10 +3396,8 @@ def sample_delete_feature_values(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.DeleteFeatureValuesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.DeleteFeatureValuesRequest): request = featurestore_service.DeleteFeatureValuesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3607,8 +3584,8 @@ def sample_search_features(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([location, query]) if request is not None and has_flattened_params: raise ValueError( @@ -3616,10 +3593,8 @@ def sample_search_features(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a featurestore_service.SearchFeaturesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, featurestore_service.SearchFeaturesRequest): request = featurestore_service.SearchFeaturesRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py index b8a1bd9857..9fd42824a3 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py @@ -61,7 +61,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -81,14 +81,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. 
scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -98,11 +101,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -129,7 +132,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -170,7 +173,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py index aacbe9e83d..b5fac80296 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -76,7 +78,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -106,7 +107,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -126,15 +127,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -144,11 +148,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -175,7 +179,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -215,7 +219,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -911,6 +917,116 @@ def search_features( ) return self._stubs["search_features"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_featurestore: gapic_v1.method_async.wrap_method( + self.create_featurestore, + default_timeout=5.0, + client_info=client_info, + ), + self.get_featurestore: gapic_v1.method_async.wrap_method( + self.get_featurestore, + default_timeout=5.0, + client_info=client_info, + ), + self.list_featurestores: gapic_v1.method_async.wrap_method( + self.list_featurestores, + default_timeout=5.0, + client_info=client_info, + ), + self.update_featurestore: gapic_v1.method_async.wrap_method( + self.update_featurestore, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_featurestore: gapic_v1.method_async.wrap_method( + self.delete_featurestore, + default_timeout=5.0, + client_info=client_info, + ), + self.create_entity_type: gapic_v1.method_async.wrap_method( + self.create_entity_type, + default_timeout=5.0, + client_info=client_info, + ), + self.get_entity_type: gapic_v1.method_async.wrap_method( + self.get_entity_type, + default_timeout=5.0, + client_info=client_info, + ), + self.list_entity_types: gapic_v1.method_async.wrap_method( + self.list_entity_types, + default_timeout=5.0, + client_info=client_info, + ), + self.update_entity_type: gapic_v1.method_async.wrap_method( + self.update_entity_type, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_entity_type: gapic_v1.method_async.wrap_method( + self.delete_entity_type, + 
default_timeout=5.0, + client_info=client_info, + ), + self.create_feature: gapic_v1.method_async.wrap_method( + self.create_feature, + default_timeout=5.0, + client_info=client_info, + ), + self.batch_create_features: gapic_v1.method_async.wrap_method( + self.batch_create_features, + default_timeout=5.0, + client_info=client_info, + ), + self.get_feature: gapic_v1.method_async.wrap_method( + self.get_feature, + default_timeout=5.0, + client_info=client_info, + ), + self.list_features: gapic_v1.method_async.wrap_method( + self.list_features, + default_timeout=5.0, + client_info=client_info, + ), + self.update_feature: gapic_v1.method_async.wrap_method( + self.update_feature, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_feature: gapic_v1.method_async.wrap_method( + self.delete_feature, + default_timeout=5.0, + client_info=client_info, + ), + self.import_feature_values: gapic_v1.method_async.wrap_method( + self.import_feature_values, + default_timeout=5.0, + client_info=client_info, + ), + self.batch_read_feature_values: gapic_v1.method_async.wrap_method( + self.batch_read_feature_values, + default_timeout=5.0, + client_info=client_info, + ), + self.export_feature_values: gapic_v1.method_async.wrap_method( + self.export_feature_values, + default_timeout=None, + client_info=client_info, + ), + self.delete_feature_values: gapic_v1.method_async.wrap_method( + self.delete_feature_values, + default_timeout=None, + client_info=client_info, + ), + self.search_features: gapic_v1.method_async.wrap_method( + self.search_features, + default_timeout=5.0, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/rest.py index 0aa3f8bf2c..7599657770 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/rest.py +++ 
b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/rest.py @@ -1297,10 +1297,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1671,10 +1667,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -2033,10 +2025,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2411,10 +2399,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -2789,10 +2773,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -5856,10 +5836,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -6287,10 +6263,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -6709,10 +6681,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -7148,10 +7116,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -7587,10 +7551,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff 
--git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py index 072f54d26d..268bd9bca8 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -212,7 +213,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, IndexEndpointServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + IndexEndpointServiceTransport, + Callable[..., IndexEndpointServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -224,9 +231,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.IndexEndpointServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,IndexEndpointServiceTransport,Callable[..., IndexEndpointServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the IndexEndpointServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -355,8 +364,8 @@ async def sample_create_index_endpoint(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, index_endpoint]) if request is not None and has_flattened_params: raise ValueError( @@ -364,7 +373,10 @@ async def sample_create_index_endpoint(): "the individual field arguments should be set." ) - request = index_endpoint_service.CreateIndexEndpointRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_endpoint_service.CreateIndexEndpointRequest): + request = index_endpoint_service.CreateIndexEndpointRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -375,11 +387,9 @@ async def sample_create_index_endpoint(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_index_endpoint, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_index_endpoint + ] # Certain fields should be provided within the metadata header; # add these here. @@ -474,8 +484,8 @@ async def sample_get_index_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -483,7 +493,10 @@ async def sample_get_index_endpoint(): "the individual field arguments should be set." 
) - request = index_endpoint_service.GetIndexEndpointRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_endpoint_service.GetIndexEndpointRequest): + request = index_endpoint_service.GetIndexEndpointRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -492,11 +505,9 @@ async def sample_get_index_endpoint(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_index_endpoint, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_index_endpoint + ] # Certain fields should be provided within the metadata header; # add these here. @@ -586,8 +597,8 @@ async def sample_list_index_endpoints(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -595,7 +606,10 @@ async def sample_list_index_endpoints(): "the individual field arguments should be set." ) - request = index_endpoint_service.ListIndexEndpointsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_endpoint_service.ListIndexEndpointsRequest): + request = index_endpoint_service.ListIndexEndpointsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -604,11 +618,9 @@ async def sample_list_index_endpoints(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_index_endpoints, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_index_endpoints + ] # Certain fields should be provided within the metadata header; # add these here. @@ -714,8 +726,8 @@ async def sample_update_index_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -723,7 +735,10 @@ async def sample_update_index_endpoint(): "the individual field arguments should be set." ) - request = index_endpoint_service.UpdateIndexEndpointRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_endpoint_service.UpdateIndexEndpointRequest): + request = index_endpoint_service.UpdateIndexEndpointRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -734,11 +749,9 @@ async def sample_update_index_endpoint(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_index_endpoint, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_index_endpoint + ] # Certain fields should be provided within the metadata header; # add these here. @@ -840,8 +853,8 @@ async def sample_delete_index_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -849,7 +862,10 @@ async def sample_delete_index_endpoint(): "the individual field arguments should be set." ) - request = index_endpoint_service.DeleteIndexEndpointRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_endpoint_service.DeleteIndexEndpointRequest): + request = index_endpoint_service.DeleteIndexEndpointRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -858,11 +874,9 @@ async def sample_delete_index_endpoint(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_index_endpoint, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_index_endpoint + ] # Certain fields should be provided within the metadata header; # add these here. @@ -977,8 +991,8 @@ async def sample_deploy_index(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index]) if request is not None and has_flattened_params: raise ValueError( @@ -986,7 +1000,10 @@ async def sample_deploy_index(): "the individual field arguments should be set." ) - request = index_endpoint_service.DeployIndexRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_endpoint_service.DeployIndexRequest): + request = index_endpoint_service.DeployIndexRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -997,11 +1014,9 @@ async def sample_deploy_index(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.deploy_index, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.deploy_index + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1114,8 +1129,8 @@ async def sample_undeploy_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1123,7 +1138,10 @@ async def sample_undeploy_index(): "the individual field arguments should be set." 
) - request = index_endpoint_service.UndeployIndexRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_endpoint_service.UndeployIndexRequest): + request = index_endpoint_service.UndeployIndexRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1134,11 +1152,9 @@ async def sample_undeploy_index(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.undeploy_index, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.undeploy_index + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1256,8 +1272,8 @@ async def sample_mutate_deployed_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index]) if request is not None and has_flattened_params: raise ValueError( @@ -1265,7 +1281,10 @@ async def sample_mutate_deployed_index(): "the individual field arguments should be set." ) - request = index_endpoint_service.MutateDeployedIndexRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_endpoint_service.MutateDeployedIndexRequest): + request = index_endpoint_service.MutateDeployedIndexRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -1276,11 +1295,9 @@ async def sample_mutate_deployed_index(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.mutate_deployed_index, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.mutate_deployed_index + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py index 196d94f106..f3e602ff42 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -564,7 +565,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, IndexEndpointServiceTransport]] = None, + transport: Optional[ + Union[ + str, + IndexEndpointServiceTransport, + Callable[..., IndexEndpointServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -576,9 +583,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, IndexEndpointServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,IndexEndpointServiceTransport,Callable[..., IndexEndpointServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. 
+ If a Callable is given, it will be called with the same set of initialization + arguments as used in the IndexEndpointServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -690,8 +699,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[IndexEndpointServiceTransport], + Callable[..., IndexEndpointServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., IndexEndpointServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -785,8 +802,8 @@ def sample_create_index_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, index_endpoint]) if request is not None and has_flattened_params: raise ValueError( @@ -794,10 +811,8 @@ def sample_create_index_endpoint(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_endpoint_service.CreateIndexEndpointRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, index_endpoint_service.CreateIndexEndpointRequest): request = index_endpoint_service.CreateIndexEndpointRequest(request) # If we have keyword arguments corresponding to fields on the @@ -904,8 +919,8 @@ def sample_get_index_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -913,10 +928,8 @@ def sample_get_index_endpoint(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_endpoint_service.GetIndexEndpointRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_endpoint_service.GetIndexEndpointRequest): request = index_endpoint_service.GetIndexEndpointRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1016,8 +1029,8 @@ def sample_list_index_endpoints(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1025,10 +1038,8 @@ def sample_list_index_endpoints(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a index_endpoint_service.ListIndexEndpointsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_endpoint_service.ListIndexEndpointsRequest): request = index_endpoint_service.ListIndexEndpointsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1144,8 +1155,8 @@ def sample_update_index_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1153,10 +1164,8 @@ def sample_update_index_endpoint(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_endpoint_service.UpdateIndexEndpointRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_endpoint_service.UpdateIndexEndpointRequest): request = index_endpoint_service.UpdateIndexEndpointRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1270,8 +1279,8 @@ def sample_delete_index_endpoint(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1279,10 +1288,8 @@ def sample_delete_index_endpoint(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_endpoint_service.DeleteIndexEndpointRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_endpoint_service.DeleteIndexEndpointRequest): request = index_endpoint_service.DeleteIndexEndpointRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1407,8 +1414,8 @@ def sample_deploy_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index]) if request is not None and has_flattened_params: raise ValueError( @@ -1416,10 +1423,8 @@ def sample_deploy_index(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_endpoint_service.DeployIndexRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, index_endpoint_service.DeployIndexRequest): request = index_endpoint_service.DeployIndexRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1544,8 +1549,8 @@ def sample_undeploy_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1553,10 +1558,8 @@ def sample_undeploy_index(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_endpoint_service.UndeployIndexRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_endpoint_service.UndeployIndexRequest): request = index_endpoint_service.UndeployIndexRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1686,8 +1689,8 @@ def sample_mutate_deployed_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([index_endpoint, deployed_index]) if request is not None and has_flattened_params: raise ValueError( @@ -1695,10 +1698,8 @@ def sample_mutate_deployed_index(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a index_endpoint_service.MutateDeployedIndexRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_endpoint_service.MutateDeployedIndexRequest): request = index_endpoint_service.MutateDeployedIndexRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc.py index b16b66a050..9607c4d526 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc.py @@ -57,7 +57,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -77,14 +77,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. 
scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -94,11 +97,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -125,7 +128,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -166,7 +169,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc_asyncio.py index 3088c84e11..4799a1e121 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -72,7 +74,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -102,7 +103,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -122,15 +123,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -140,11 +144,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -171,7 +175,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -211,7 +215,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -492,6 +498,51 @@ def mutate_deployed_index( ) return self._stubs["mutate_deployed_index"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_index_endpoint: gapic_v1.method_async.wrap_method( + self.create_index_endpoint, + default_timeout=5.0, + client_info=client_info, + ), + self.get_index_endpoint: gapic_v1.method_async.wrap_method( + self.get_index_endpoint, + default_timeout=5.0, + client_info=client_info, + ), + self.list_index_endpoints: gapic_v1.method_async.wrap_method( + self.list_index_endpoints, + default_timeout=5.0, + client_info=client_info, + ), + self.update_index_endpoint: gapic_v1.method_async.wrap_method( + self.update_index_endpoint, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_index_endpoint: gapic_v1.method_async.wrap_method( + self.delete_index_endpoint, + default_timeout=5.0, + client_info=client_info, + ), + self.deploy_index: gapic_v1.method_async.wrap_method( + self.deploy_index, + default_timeout=5.0, + client_info=client_info, + ), + self.undeploy_index: gapic_v1.method_async.wrap_method( + self.undeploy_index, + default_timeout=5.0, + client_info=client_info, + ), + self.mutate_deployed_index: gapic_v1.method_async.wrap_method( + self.mutate_deployed_index, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/rest.py 
b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/rest.py index 46d1bb8b0c..eb0d4d346f 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/rest.py @@ -888,10 +888,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1262,10 +1258,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -1624,10 +1616,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2002,10 +1990,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -2380,10 +2364,6 @@ def 
operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -4125,10 +4105,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -4556,10 +4532,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -4978,10 +4950,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -5417,10 +5385,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -5856,10 +5820,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py index 938829857b..58617fa0ef 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -210,7 +211,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, IndexServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[str, IndexServiceTransport, Callable[..., IndexServiceTransport]] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -222,9 +225,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.IndexServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,IndexServiceTransport,Callable[..., IndexServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the IndexServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -350,8 +355,8 @@ async def sample_create_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, index]) if request is not None and has_flattened_params: raise ValueError( @@ -359,7 +364,10 @@ async def sample_create_index(): "the individual field arguments should be set." ) - request = index_service.CreateIndexRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_service.CreateIndexRequest): + request = index_service.CreateIndexRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -370,11 +378,9 @@ async def sample_create_index(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_index, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_index + ] # Certain fields should be provided within the metadata header; # add these here. @@ -467,8 +473,8 @@ async def sample_get_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -476,7 +482,10 @@ async def sample_get_index(): "the individual field arguments should be set." 
) - request = index_service.GetIndexRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_service.GetIndexRequest): + request = index_service.GetIndexRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -485,11 +494,9 @@ async def sample_get_index(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_index, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_index + ] # Certain fields should be provided within the metadata header; # add these here. @@ -577,8 +584,8 @@ async def sample_list_indexes(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -586,7 +593,10 @@ async def sample_list_indexes(): "the individual field arguments should be set." ) - request = index_service.ListIndexesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_service.ListIndexesRequest): + request = index_service.ListIndexesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -595,11 +605,9 @@ async def sample_list_indexes(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_indexes, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_indexes + ] # Certain fields should be provided within the metadata header; # add these here. @@ -710,8 +718,8 @@ async def sample_update_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([index, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -719,7 +727,10 @@ async def sample_update_index(): "the individual field arguments should be set." ) - request = index_service.UpdateIndexRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_service.UpdateIndexRequest): + request = index_service.UpdateIndexRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -730,11 +741,9 @@ async def sample_update_index(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_index, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_index + ] # Certain fields should be provided within the metadata header; # add these here. @@ -844,8 +853,8 @@ async def sample_delete_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -853,7 +862,10 @@ async def sample_delete_index(): "the individual field arguments should be set." ) - request = index_service.DeleteIndexRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_service.DeleteIndexRequest): + request = index_service.DeleteIndexRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -862,11 +874,9 @@ async def sample_delete_index(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_index, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_index + ] # Certain fields should be provided within the metadata header; # add these here. @@ -949,15 +959,16 @@ async def sample_upsert_datapoints(): """ # Create or coerce a protobuf request object. - request = index_service.UpsertDatapointsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_service.UpsertDatapointsRequest): + request = index_service.UpsertDatapointsRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.upsert_datapoints, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.upsert_datapoints + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1032,15 +1043,16 @@ async def sample_remove_datapoints(): """ # Create or coerce a protobuf request object. - request = index_service.RemoveDatapointsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, index_service.RemoveDatapointsRequest): + request = index_service.RemoveDatapointsRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.remove_datapoints, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.remove_datapoints + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/client.py b/google/cloud/aiplatform_v1beta1/services/index_service/client.py index c4389e12f4..cf715bdc91 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -565,7 +566,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, IndexServiceTransport]] = None, + transport: Optional[ + Union[str, IndexServiceTransport, Callable[..., IndexServiceTransport]] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -577,9 +580,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, IndexServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,IndexServiceTransport,Callable[..., IndexServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the IndexServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -688,8 +693,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[IndexServiceTransport], Callable[..., IndexServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., IndexServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -780,8 +792,8 @@ def sample_create_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, index]) if request is not None and has_flattened_params: raise ValueError( @@ -789,10 +801,8 @@ def sample_create_index(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_service.CreateIndexRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_service.CreateIndexRequest): request = index_service.CreateIndexRequest(request) # If we have keyword arguments corresponding to fields on the @@ -897,8 +907,8 @@ def sample_get_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -906,10 +916,8 @@ def sample_get_index(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_service.GetIndexRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_service.GetIndexRequest): request = index_service.GetIndexRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1007,8 +1015,8 @@ def sample_list_indexes(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1016,10 +1024,8 @@ def sample_list_indexes(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_service.ListIndexesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, index_service.ListIndexesRequest): request = index_service.ListIndexesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1140,8 +1146,8 @@ def sample_update_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([index, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1149,10 +1155,8 @@ def sample_update_index(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_service.UpdateIndexRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_service.UpdateIndexRequest): request = index_service.UpdateIndexRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1274,8 +1278,8 @@ def sample_delete_index(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1283,10 +1287,8 @@ def sample_delete_index(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a index_service.DeleteIndexRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_service.DeleteIndexRequest): request = index_service.DeleteIndexRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1379,10 +1381,8 @@ def sample_upsert_datapoints(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a index_service.UpsertDatapointsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, index_service.UpsertDatapointsRequest): request = index_service.UpsertDatapointsRequest(request) @@ -1463,10 +1463,8 @@ def sample_remove_datapoints(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a index_service.RemoveDatapointsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, index_service.RemoveDatapointsRequest): request = index_service.RemoveDatapointsRequest(request) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc.py index 7da6c81566..33d16a6b7b 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc.py @@ -57,7 +57,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -77,14 +77,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. 
If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -94,11 +97,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -125,7 +128,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -166,7 +169,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc_asyncio.py index ef2e9639e6..4c4d28974a 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -72,7 +74,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -102,7 +103,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -122,15 +123,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -140,11 +144,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -171,7 +175,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -211,7 +215,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -454,6 +460,46 @@ def remove_datapoints( ) return self._stubs["remove_datapoints"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_index: gapic_v1.method_async.wrap_method( + self.create_index, + default_timeout=5.0, + client_info=client_info, + ), + self.get_index: gapic_v1.method_async.wrap_method( + self.get_index, + default_timeout=5.0, + client_info=client_info, + ), + self.list_indexes: gapic_v1.method_async.wrap_method( + self.list_indexes, + default_timeout=5.0, + client_info=client_info, + ), + self.update_index: gapic_v1.method_async.wrap_method( + self.update_index, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_index: gapic_v1.method_async.wrap_method( + self.delete_index, + default_timeout=5.0, + client_info=client_info, + ), + self.upsert_datapoints: gapic_v1.method_async.wrap_method( + self.upsert_datapoints, + default_timeout=None, + client_info=client_info, + ), + self.remove_datapoints: gapic_v1.method_async.wrap_method( + self.remove_datapoints, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/index_service/transports/rest.py index 817a783095..56ee7fd395 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/transports/rest.py @@ -840,10 +840,6 @@ def 
operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1214,10 +1210,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -1576,10 +1568,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -1954,10 +1942,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -2332,10 +2316,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": 
"/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -3954,10 +3934,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -4385,10 +4361,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -4807,10 +4779,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -5246,10 +5214,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -5685,10 +5649,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py 
b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py index 4168607063..a225f26a2f 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -279,7 +280,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, JobServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[str, JobServiceTransport, Callable[..., JobServiceTransport]] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -291,9 +294,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.JobServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,JobServiceTransport,Callable[..., JobServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the JobServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -420,8 +425,8 @@ async def sample_create_custom_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, custom_job]) if request is not None and has_flattened_params: raise ValueError( @@ -429,7 +434,10 @@ async def sample_create_custom_job(): "the individual field arguments should be set." ) - request = job_service.CreateCustomJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.CreateCustomJobRequest): + request = job_service.CreateCustomJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -440,11 +448,9 @@ async def sample_create_custom_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_custom_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_custom_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -533,8 +539,8 @@ async def sample_get_custom_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -542,7 +548,10 @@ async def sample_get_custom_job(): "the individual field arguments should be set." ) - request = job_service.GetCustomJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, job_service.GetCustomJobRequest): + request = job_service.GetCustomJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -551,11 +560,9 @@ async def sample_get_custom_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_custom_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_custom_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -643,8 +650,8 @@ async def sample_list_custom_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -652,7 +659,10 @@ async def sample_list_custom_jobs(): "the individual field arguments should be set." ) - request = job_service.ListCustomJobsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.ListCustomJobsRequest): + request = job_service.ListCustomJobsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -661,11 +671,9 @@ async def sample_list_custom_jobs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_custom_jobs, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_custom_jobs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -772,8 +780,8 @@ async def sample_delete_custom_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -781,7 +789,10 @@ async def sample_delete_custom_job(): "the individual field arguments should be set." ) - request = job_service.DeleteCustomJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.DeleteCustomJobRequest): + request = job_service.DeleteCustomJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -790,11 +801,9 @@ async def sample_delete_custom_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_custom_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_custom_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -888,8 +897,8 @@ async def sample_cancel_custom_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -897,7 +906,10 @@ async def sample_cancel_custom_job(): "the individual field arguments should be set." ) - request = job_service.CancelCustomJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.CancelCustomJobRequest): + request = job_service.CancelCustomJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -906,11 +918,9 @@ async def sample_cancel_custom_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_custom_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.cancel_custom_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1008,8 +1018,8 @@ async def sample_create_data_labeling_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, data_labeling_job]) if request is not None and has_flattened_params: raise ValueError( @@ -1017,7 +1027,10 @@ async def sample_create_data_labeling_job(): "the individual field arguments should be set." 
) - request = job_service.CreateDataLabelingJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.CreateDataLabelingJobRequest): + request = job_service.CreateDataLabelingJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1028,11 +1041,9 @@ async def sample_create_data_labeling_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_data_labeling_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_data_labeling_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1116,8 +1127,8 @@ async def sample_get_data_labeling_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1125,7 +1136,10 @@ async def sample_get_data_labeling_job(): "the individual field arguments should be set." ) - request = job_service.GetDataLabelingJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.GetDataLabelingJobRequest): + request = job_service.GetDataLabelingJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -1134,11 +1148,9 @@ async def sample_get_data_labeling_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_data_labeling_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_data_labeling_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1225,8 +1237,8 @@ async def sample_list_data_labeling_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1234,7 +1246,10 @@ async def sample_list_data_labeling_jobs(): "the individual field arguments should be set." ) - request = job_service.ListDataLabelingJobsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.ListDataLabelingJobsRequest): + request = job_service.ListDataLabelingJobsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1243,11 +1258,9 @@ async def sample_list_data_labeling_jobs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_data_labeling_jobs, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_data_labeling_jobs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1354,8 +1367,8 @@ async def sample_delete_data_labeling_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1363,7 +1376,10 @@ async def sample_delete_data_labeling_job(): "the individual field arguments should be set." ) - request = job_service.DeleteDataLabelingJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.DeleteDataLabelingJobRequest): + request = job_service.DeleteDataLabelingJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1372,11 +1388,9 @@ async def sample_delete_data_labeling_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_data_labeling_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_data_labeling_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1459,8 +1473,8 @@ async def sample_cancel_data_labeling_job(): sent along with the request as metadata. 
""" # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1468,7 +1482,10 @@ async def sample_cancel_data_labeling_job(): "the individual field arguments should be set." ) - request = job_service.CancelDataLabelingJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.CancelDataLabelingJobRequest): + request = job_service.CancelDataLabelingJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1477,11 +1494,9 @@ async def sample_cancel_data_labeling_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_data_labeling_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.cancel_data_labeling_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1588,8 +1603,8 @@ async def sample_create_hyperparameter_tuning_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, hyperparameter_tuning_job]) if request is not None and has_flattened_params: raise ValueError( @@ -1597,7 +1612,10 @@ async def sample_create_hyperparameter_tuning_job(): "the individual field arguments should be set." ) - request = job_service.CreateHyperparameterTuningJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.CreateHyperparameterTuningJobRequest): + request = job_service.CreateHyperparameterTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1608,11 +1626,9 @@ async def sample_create_hyperparameter_tuning_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_hyperparameter_tuning_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1700,8 +1716,8 @@ async def sample_get_hyperparameter_tuning_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1709,7 +1725,10 @@ async def sample_get_hyperparameter_tuning_job(): "the individual field arguments should be set." 
) - request = job_service.GetHyperparameterTuningJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.GetHyperparameterTuningJobRequest): + request = job_service.GetHyperparameterTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1718,11 +1737,9 @@ async def sample_get_hyperparameter_tuning_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_hyperparameter_tuning_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1812,8 +1829,8 @@ async def sample_list_hyperparameter_tuning_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1821,7 +1838,10 @@ async def sample_list_hyperparameter_tuning_jobs(): "the individual field arguments should be set." ) - request = job_service.ListHyperparameterTuningJobsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.ListHyperparameterTuningJobsRequest): + request = job_service.ListHyperparameterTuningJobsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -1830,11 +1850,9 @@ async def sample_list_hyperparameter_tuning_jobs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_hyperparameter_tuning_jobs, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_hyperparameter_tuning_jobs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1943,8 +1961,8 @@ async def sample_delete_hyperparameter_tuning_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1952,7 +1970,10 @@ async def sample_delete_hyperparameter_tuning_job(): "the individual field arguments should be set." ) - request = job_service.DeleteHyperparameterTuningJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.DeleteHyperparameterTuningJobRequest): + request = job_service.DeleteHyperparameterTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1961,11 +1982,9 @@ async def sample_delete_hyperparameter_tuning_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_hyperparameter_tuning_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2063,8 +2082,8 @@ async def sample_cancel_hyperparameter_tuning_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2072,7 +2091,10 @@ async def sample_cancel_hyperparameter_tuning_job(): "the individual field arguments should be set." ) - request = job_service.CancelHyperparameterTuningJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.CancelHyperparameterTuningJobRequest): + request = job_service.CancelHyperparameterTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2081,11 +2103,9 @@ async def sample_cancel_hyperparameter_tuning_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_hyperparameter_tuning_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.cancel_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -2179,8 +2199,8 @@ async def sample_create_nas_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, nas_job]) if request is not None and has_flattened_params: raise ValueError( @@ -2188,7 +2208,10 @@ async def sample_create_nas_job(): "the individual field arguments should be set." ) - request = job_service.CreateNasJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.CreateNasJobRequest): + request = job_service.CreateNasJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2199,11 +2222,9 @@ async def sample_create_nas_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_nas_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_nas_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2286,8 +2307,8 @@ async def sample_get_nas_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2295,7 +2316,10 @@ async def sample_get_nas_job(): "the individual field arguments should be set." ) - request = job_service.GetNasJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.GetNasJobRequest): + request = job_service.GetNasJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2304,11 +2328,9 @@ async def sample_get_nas_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_nas_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_nas_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2396,8 +2418,8 @@ async def sample_list_nas_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2405,7 +2427,10 @@ async def sample_list_nas_jobs(): "the individual field arguments should be set." ) - request = job_service.ListNasJobsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, job_service.ListNasJobsRequest): + request = job_service.ListNasJobsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2414,11 +2439,9 @@ async def sample_list_nas_jobs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_nas_jobs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_nas_jobs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2525,8 +2548,8 @@ async def sample_delete_nas_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2534,7 +2557,10 @@ async def sample_delete_nas_job(): "the individual field arguments should be set." ) - request = job_service.DeleteNasJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.DeleteNasJobRequest): + request = job_service.DeleteNasJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2543,11 +2569,9 @@ async def sample_delete_nas_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_nas_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_nas_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2641,8 +2665,8 @@ async def sample_cancel_nas_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2650,7 +2674,10 @@ async def sample_cancel_nas_job(): "the individual field arguments should be set." ) - request = job_service.CancelNasJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.CancelNasJobRequest): + request = job_service.CancelNasJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2659,11 +2686,9 @@ async def sample_cancel_nas_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_nas_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.cancel_nas_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2746,8 +2771,8 @@ async def sample_get_nas_trial_detail(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2755,7 +2780,10 @@ async def sample_get_nas_trial_detail(): "the individual field arguments should be set." ) - request = job_service.GetNasTrialDetailRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.GetNasTrialDetailRequest): + request = job_service.GetNasTrialDetailRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2764,11 +2792,9 @@ async def sample_get_nas_trial_detail(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_nas_trial_detail, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_nas_trial_detail + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2855,8 +2881,8 @@ async def sample_list_nas_trial_details(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2864,7 +2890,10 @@ async def sample_list_nas_trial_details(): "the individual field arguments should be set." 
) - request = job_service.ListNasTrialDetailsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.ListNasTrialDetailsRequest): + request = job_service.ListNasTrialDetailsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2873,11 +2902,9 @@ async def sample_list_nas_trial_details(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_nas_trial_details, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_nas_trial_details + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2996,8 +3023,8 @@ async def sample_create_batch_prediction_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, batch_prediction_job]) if request is not None and has_flattened_params: raise ValueError( @@ -3005,7 +3032,10 @@ async def sample_create_batch_prediction_job(): "the individual field arguments should be set." ) - request = job_service.CreateBatchPredictionJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.CreateBatchPredictionJobRequest): + request = job_service.CreateBatchPredictionJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -3016,11 +3046,9 @@ async def sample_create_batch_prediction_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_batch_prediction_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3109,8 +3137,8 @@ async def sample_get_batch_prediction_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3118,7 +3146,10 @@ async def sample_get_batch_prediction_job(): "the individual field arguments should be set." ) - request = job_service.GetBatchPredictionJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.GetBatchPredictionJobRequest): + request = job_service.GetBatchPredictionJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3127,11 +3158,9 @@ async def sample_get_batch_prediction_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_batch_prediction_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3221,8 +3250,8 @@ async def sample_list_batch_prediction_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3230,7 +3259,10 @@ async def sample_list_batch_prediction_jobs(): "the individual field arguments should be set." ) - request = job_service.ListBatchPredictionJobsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.ListBatchPredictionJobsRequest): + request = job_service.ListBatchPredictionJobsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3239,11 +3271,9 @@ async def sample_list_batch_prediction_jobs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_batch_prediction_jobs, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_batch_prediction_jobs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3353,8 +3383,8 @@ async def sample_delete_batch_prediction_job(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3362,7 +3392,10 @@ async def sample_delete_batch_prediction_job(): "the individual field arguments should be set." ) - request = job_service.DeleteBatchPredictionJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.DeleteBatchPredictionJobRequest): + request = job_service.DeleteBatchPredictionJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3371,11 +3404,9 @@ async def sample_delete_batch_prediction_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_batch_prediction_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3471,8 +3502,8 @@ async def sample_cancel_batch_prediction_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3480,7 +3511,10 @@ async def sample_cancel_batch_prediction_job(): "the individual field arguments should be set." ) - request = job_service.CancelBatchPredictionJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.CancelBatchPredictionJobRequest): + request = job_service.CancelBatchPredictionJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3489,11 +3523,9 @@ async def sample_cancel_batch_prediction_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_batch_prediction_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.cancel_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3595,8 +3627,8 @@ async def sample_create_model_deployment_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_deployment_monitoring_job]) if request is not None and has_flattened_params: raise ValueError( @@ -3604,7 +3636,12 @@ async def sample_create_model_deployment_monitoring_job(): "the individual field arguments should be set." 
) - request = job_service.CreateModelDeploymentMonitoringJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, job_service.CreateModelDeploymentMonitoringJobRequest + ): + request = job_service.CreateModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3615,11 +3652,9 @@ async def sample_create_model_deployment_monitoring_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_model_deployment_monitoring_job, - default_timeout=60.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3721,8 +3756,8 @@ async def sample_search_model_deployment_monitoring_stats_anomalies(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model_deployment_monitoring_job, deployed_model_id]) if request is not None and has_flattened_params: raise ValueError( @@ -3730,9 +3765,14 @@ async def sample_search_model_deployment_monitoring_stats_anomalies(): "the individual field arguments should be set." ) - request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance( + request, job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest + ): + request = job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3743,11 +3783,9 @@ async def sample_search_model_deployment_monitoring_stats_anomalies(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.search_model_deployment_monitoring_stats_anomalies, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.search_model_deployment_monitoring_stats_anomalies + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3852,8 +3890,8 @@ async def sample_get_model_deployment_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3861,7 +3899,10 @@ async def sample_get_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - request = job_service.GetModelDeploymentMonitoringJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, job_service.GetModelDeploymentMonitoringJobRequest): + request = job_service.GetModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -3870,11 +3911,9 @@ async def sample_get_model_deployment_monitoring_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_model_deployment_monitoring_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3964,8 +4003,8 @@ async def sample_list_model_deployment_monitoring_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3973,7 +4012,12 @@ async def sample_list_model_deployment_monitoring_jobs(): "the individual field arguments should be set." ) - request = job_service.ListModelDeploymentMonitoringJobsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, job_service.ListModelDeploymentMonitoringJobsRequest + ): + request = job_service.ListModelDeploymentMonitoringJobsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3982,11 +4026,9 @@ async def sample_list_model_deployment_monitoring_jobs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_model_deployment_monitoring_jobs, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_model_deployment_monitoring_jobs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -4128,8 +4170,8 @@ async def sample_update_model_deployment_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model_deployment_monitoring_job, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -4137,7 +4179,12 @@ async def sample_update_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - request = job_service.UpdateModelDeploymentMonitoringJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, job_service.UpdateModelDeploymentMonitoringJobRequest + ): + request = job_service.UpdateModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -4148,11 +4195,9 @@ async def sample_update_model_deployment_monitoring_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_model_deployment_monitoring_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -4267,8 +4312,8 @@ async def sample_delete_model_deployment_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -4276,7 +4321,12 @@ async def sample_delete_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - request = job_service.DeleteModelDeploymentMonitoringJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, job_service.DeleteModelDeploymentMonitoringJobRequest + ): + request = job_service.DeleteModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -4285,11 +4335,9 @@ async def sample_delete_model_deployment_monitoring_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_model_deployment_monitoring_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -4377,8 +4425,8 @@ async def sample_pause_model_deployment_monitoring_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -4386,7 +4434,12 @@ async def sample_pause_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - request = job_service.PauseModelDeploymentMonitoringJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, job_service.PauseModelDeploymentMonitoringJobRequest + ): + request = job_service.PauseModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -4395,11 +4448,9 @@ async def sample_pause_model_deployment_monitoring_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.pause_model_deployment_monitoring_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.pause_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -4475,8 +4526,8 @@ async def sample_resume_model_deployment_monitoring_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -4484,7 +4535,12 @@ async def sample_resume_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - request = job_service.ResumeModelDeploymentMonitoringJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, job_service.ResumeModelDeploymentMonitoringJobRequest + ): + request = job_service.ResumeModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -4493,11 +4549,9 @@ async def sample_resume_model_deployment_monitoring_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.resume_model_deployment_monitoring_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.resume_model_deployment_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/client.py b/google/cloud/aiplatform_v1beta1/services/job_service/client.py index d1536a5608..8f1fd51d84 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -903,7 +904,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, JobServiceTransport]] = None, + transport: Optional[ + Union[str, JobServiceTransport, Callable[..., JobServiceTransport]] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -915,9 +918,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, JobServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,JobServiceTransport,Callable[..., JobServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the JobServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -1026,8 +1031,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[JobServiceTransport], Callable[..., JobServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., JobServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -1119,8 +1131,8 @@ def sample_create_custom_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, custom_job]) if request is not None and has_flattened_params: raise ValueError( @@ -1128,10 +1140,8 @@ def sample_create_custom_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CreateCustomJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.CreateCustomJobRequest): request = job_service.CreateCustomJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1232,8 +1242,8 @@ def sample_get_custom_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1241,10 +1251,8 @@ def sample_get_custom_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.GetCustomJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.GetCustomJobRequest): request = job_service.GetCustomJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1342,8 +1350,8 @@ def sample_list_custom_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1351,10 +1359,8 @@ def sample_list_custom_jobs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.ListCustomJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, job_service.ListCustomJobsRequest): request = job_service.ListCustomJobsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1471,8 +1477,8 @@ def sample_delete_custom_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1480,10 +1486,8 @@ def sample_delete_custom_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.DeleteCustomJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.DeleteCustomJobRequest): request = job_service.DeleteCustomJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1587,8 +1591,8 @@ def sample_cancel_custom_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1596,10 +1600,8 @@ def sample_cancel_custom_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CancelCustomJobRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.CancelCustomJobRequest): request = job_service.CancelCustomJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1707,8 +1709,8 @@ def sample_create_data_labeling_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, data_labeling_job]) if request is not None and has_flattened_params: raise ValueError( @@ -1716,10 +1718,8 @@ def sample_create_data_labeling_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CreateDataLabelingJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.CreateDataLabelingJobRequest): request = job_service.CreateDataLabelingJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1815,8 +1815,8 @@ def sample_get_data_labeling_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1824,10 +1824,8 @@ def sample_get_data_labeling_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.GetDataLabelingJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.GetDataLabelingJobRequest): request = job_service.GetDataLabelingJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1924,8 +1922,8 @@ def sample_list_data_labeling_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1933,10 +1931,8 @@ def sample_list_data_labeling_jobs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.ListDataLabelingJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.ListDataLabelingJobsRequest): request = job_service.ListDataLabelingJobsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2053,8 +2049,8 @@ def sample_delete_data_labeling_job(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2062,10 +2058,8 @@ def sample_delete_data_labeling_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.DeleteDataLabelingJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.DeleteDataLabelingJobRequest): request = job_service.DeleteDataLabelingJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2158,8 +2152,8 @@ def sample_cancel_data_labeling_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2167,10 +2161,8 @@ def sample_cancel_data_labeling_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CancelDataLabelingJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, job_service.CancelDataLabelingJobRequest): request = job_service.CancelDataLabelingJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2287,8 +2279,8 @@ def sample_create_hyperparameter_tuning_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, hyperparameter_tuning_job]) if request is not None and has_flattened_params: raise ValueError( @@ -2296,10 +2288,8 @@ def sample_create_hyperparameter_tuning_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CreateHyperparameterTuningJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.CreateHyperparameterTuningJobRequest): request = job_service.CreateHyperparameterTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2401,8 +2391,8 @@ def sample_get_hyperparameter_tuning_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2410,10 +2400,8 @@ def sample_get_hyperparameter_tuning_job(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.GetHyperparameterTuningJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.GetHyperparameterTuningJobRequest): request = job_service.GetHyperparameterTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2515,8 +2503,8 @@ def sample_list_hyperparameter_tuning_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2524,10 +2512,8 @@ def sample_list_hyperparameter_tuning_jobs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.ListHyperparameterTuningJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.ListHyperparameterTuningJobsRequest): request = job_service.ListHyperparameterTuningJobsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2648,8 +2634,8 @@ def sample_delete_hyperparameter_tuning_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2657,10 +2643,8 @@ def sample_delete_hyperparameter_tuning_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.DeleteHyperparameterTuningJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.DeleteHyperparameterTuningJobRequest): request = job_service.DeleteHyperparameterTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2770,8 +2754,8 @@ def sample_cancel_hyperparameter_tuning_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2779,10 +2763,8 @@ def sample_cancel_hyperparameter_tuning_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CancelHyperparameterTuningJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, job_service.CancelHyperparameterTuningJobRequest): request = job_service.CancelHyperparameterTuningJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2888,8 +2870,8 @@ def sample_create_nas_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, nas_job]) if request is not None and has_flattened_params: raise ValueError( @@ -2897,10 +2879,8 @@ def sample_create_nas_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CreateNasJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.CreateNasJobRequest): request = job_service.CreateNasJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2995,8 +2975,8 @@ def sample_get_nas_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3004,10 +2984,8 @@ def sample_get_nas_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.GetNasJobRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.GetNasJobRequest): request = job_service.GetNasJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3105,8 +3083,8 @@ def sample_list_nas_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3114,10 +3092,8 @@ def sample_list_nas_jobs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.ListNasJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.ListNasJobsRequest): request = job_service.ListNasJobsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3234,8 +3210,8 @@ def sample_delete_nas_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3243,10 +3219,8 @@ def sample_delete_nas_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.DeleteNasJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.DeleteNasJobRequest): request = job_service.DeleteNasJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3350,8 +3324,8 @@ def sample_cancel_nas_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3359,10 +3333,8 @@ def sample_cancel_nas_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CancelNasJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.CancelNasJobRequest): request = job_service.CancelNasJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3455,8 +3427,8 @@ def sample_get_nas_trial_detail(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3464,10 +3436,8 @@ def sample_get_nas_trial_detail(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.GetNasTrialDetailRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.GetNasTrialDetailRequest): request = job_service.GetNasTrialDetailRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3564,8 +3534,8 @@ def sample_list_nas_trial_details(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3573,10 +3543,8 @@ def sample_list_nas_trial_details(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.ListNasTrialDetailsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, job_service.ListNasTrialDetailsRequest): request = job_service.ListNasTrialDetailsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3705,8 +3673,8 @@ def sample_create_batch_prediction_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, batch_prediction_job]) if request is not None and has_flattened_params: raise ValueError( @@ -3714,10 +3682,8 @@ def sample_create_batch_prediction_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CreateBatchPredictionJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.CreateBatchPredictionJobRequest): request = job_service.CreateBatchPredictionJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3820,8 +3786,8 @@ def sample_get_batch_prediction_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3829,10 +3795,8 @@ def sample_get_batch_prediction_job(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.GetBatchPredictionJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.GetBatchPredictionJobRequest): request = job_service.GetBatchPredictionJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3932,8 +3896,8 @@ def sample_list_batch_prediction_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3941,10 +3905,8 @@ def sample_list_batch_prediction_jobs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.ListBatchPredictionJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.ListBatchPredictionJobsRequest): request = job_service.ListBatchPredictionJobsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4066,8 +4028,8 @@ def sample_delete_batch_prediction_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -4075,10 +4037,8 @@ def sample_delete_batch_prediction_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.DeleteBatchPredictionJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.DeleteBatchPredictionJobRequest): request = job_service.DeleteBatchPredictionJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4186,8 +4146,8 @@ def sample_cancel_batch_prediction_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -4195,10 +4155,8 @@ def sample_cancel_batch_prediction_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CancelBatchPredictionJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, job_service.CancelBatchPredictionJobRequest): request = job_service.CancelBatchPredictionJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4312,8 +4270,8 @@ def sample_create_model_deployment_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_deployment_monitoring_job]) if request is not None and has_flattened_params: raise ValueError( @@ -4321,10 +4279,8 @@ def sample_create_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CreateModelDeploymentMonitoringJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, job_service.CreateModelDeploymentMonitoringJobRequest ): @@ -4444,8 +4400,8 @@ def sample_search_model_deployment_monitoring_stats_anomalies(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model_deployment_monitoring_job, deployed_model_id]) if request is not None and has_flattened_params: raise ValueError( @@ -4453,10 +4409,8 @@ def sample_search_model_deployment_monitoring_stats_anomalies(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest ): @@ -4581,8 +4535,8 @@ def sample_get_model_deployment_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -4590,10 +4544,8 @@ def sample_get_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.GetModelDeploymentMonitoringJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, job_service.GetModelDeploymentMonitoringJobRequest): request = job_service.GetModelDeploymentMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4695,8 +4647,8 @@ def sample_list_model_deployment_monitoring_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -4704,10 +4656,8 @@ def sample_list_model_deployment_monitoring_jobs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.ListModelDeploymentMonitoringJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, job_service.ListModelDeploymentMonitoringJobsRequest ): @@ -4863,8 +4813,8 @@ def sample_update_model_deployment_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model_deployment_monitoring_job, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -4872,10 +4822,8 @@ def sample_update_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.UpdateModelDeploymentMonitoringJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, job_service.UpdateModelDeploymentMonitoringJobRequest ): @@ -5008,8 +4956,8 @@ def sample_delete_model_deployment_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -5017,10 +4965,8 @@ def sample_delete_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.DeleteModelDeploymentMonitoringJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, job_service.DeleteModelDeploymentMonitoringJobRequest ): @@ -5122,8 +5068,8 @@ def sample_pause_model_deployment_monitoring_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -5131,10 +5077,8 @@ def sample_pause_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.PauseModelDeploymentMonitoringJobRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, job_service.PauseModelDeploymentMonitoringJobRequest ): @@ -5224,8 +5168,8 @@ def sample_resume_model_deployment_monitoring_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -5233,10 +5177,8 @@ def sample_resume_model_deployment_monitoring_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.ResumeModelDeploymentMonitoringJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, job_service.ResumeModelDeploymentMonitoringJobRequest ): diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py index 21fd308f1f..c07769d96c 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py @@ -76,7 +76,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -96,14 +96,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. 
If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -113,11 +116,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -144,7 +147,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -185,7 +188,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py index 3eb43274ef..793f7b3e9f 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -91,7 +93,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -121,7 +122,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -141,15 +142,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -159,11 +163,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -190,7 +194,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -230,7 +234,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -1363,6 +1369,186 @@ def resume_model_deployment_monitoring_job( ) return self._stubs["resume_model_deployment_monitoring_job"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_custom_job: gapic_v1.method_async.wrap_method( + self.create_custom_job, + default_timeout=5.0, + client_info=client_info, + ), + self.get_custom_job: gapic_v1.method_async.wrap_method( + self.get_custom_job, + default_timeout=5.0, + client_info=client_info, + ), + self.list_custom_jobs: gapic_v1.method_async.wrap_method( + self.list_custom_jobs, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_custom_job: gapic_v1.method_async.wrap_method( + self.delete_custom_job, + default_timeout=5.0, + client_info=client_info, + ), + self.cancel_custom_job: gapic_v1.method_async.wrap_method( + self.cancel_custom_job, + default_timeout=5.0, + client_info=client_info, + ), + self.create_data_labeling_job: gapic_v1.method_async.wrap_method( + self.create_data_labeling_job, + default_timeout=5.0, + client_info=client_info, + ), + self.get_data_labeling_job: gapic_v1.method_async.wrap_method( + self.get_data_labeling_job, + default_timeout=5.0, + client_info=client_info, + ), + self.list_data_labeling_jobs: gapic_v1.method_async.wrap_method( + self.list_data_labeling_jobs, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_data_labeling_job: gapic_v1.method_async.wrap_method( + self.delete_data_labeling_job, + default_timeout=5.0, + client_info=client_info, + ), + 
self.cancel_data_labeling_job: gapic_v1.method_async.wrap_method( + self.cancel_data_labeling_job, + default_timeout=5.0, + client_info=client_info, + ), + self.create_hyperparameter_tuning_job: gapic_v1.method_async.wrap_method( + self.create_hyperparameter_tuning_job, + default_timeout=5.0, + client_info=client_info, + ), + self.get_hyperparameter_tuning_job: gapic_v1.method_async.wrap_method( + self.get_hyperparameter_tuning_job, + default_timeout=5.0, + client_info=client_info, + ), + self.list_hyperparameter_tuning_jobs: gapic_v1.method_async.wrap_method( + self.list_hyperparameter_tuning_jobs, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_hyperparameter_tuning_job: gapic_v1.method_async.wrap_method( + self.delete_hyperparameter_tuning_job, + default_timeout=5.0, + client_info=client_info, + ), + self.cancel_hyperparameter_tuning_job: gapic_v1.method_async.wrap_method( + self.cancel_hyperparameter_tuning_job, + default_timeout=5.0, + client_info=client_info, + ), + self.create_nas_job: gapic_v1.method_async.wrap_method( + self.create_nas_job, + default_timeout=None, + client_info=client_info, + ), + self.get_nas_job: gapic_v1.method_async.wrap_method( + self.get_nas_job, + default_timeout=None, + client_info=client_info, + ), + self.list_nas_jobs: gapic_v1.method_async.wrap_method( + self.list_nas_jobs, + default_timeout=None, + client_info=client_info, + ), + self.delete_nas_job: gapic_v1.method_async.wrap_method( + self.delete_nas_job, + default_timeout=None, + client_info=client_info, + ), + self.cancel_nas_job: gapic_v1.method_async.wrap_method( + self.cancel_nas_job, + default_timeout=None, + client_info=client_info, + ), + self.get_nas_trial_detail: gapic_v1.method_async.wrap_method( + self.get_nas_trial_detail, + default_timeout=None, + client_info=client_info, + ), + self.list_nas_trial_details: gapic_v1.method_async.wrap_method( + self.list_nas_trial_details, + default_timeout=None, + client_info=client_info, + ), + 
self.create_batch_prediction_job: gapic_v1.method_async.wrap_method( + self.create_batch_prediction_job, + default_timeout=5.0, + client_info=client_info, + ), + self.get_batch_prediction_job: gapic_v1.method_async.wrap_method( + self.get_batch_prediction_job, + default_timeout=5.0, + client_info=client_info, + ), + self.list_batch_prediction_jobs: gapic_v1.method_async.wrap_method( + self.list_batch_prediction_jobs, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_batch_prediction_job: gapic_v1.method_async.wrap_method( + self.delete_batch_prediction_job, + default_timeout=5.0, + client_info=client_info, + ), + self.cancel_batch_prediction_job: gapic_v1.method_async.wrap_method( + self.cancel_batch_prediction_job, + default_timeout=5.0, + client_info=client_info, + ), + self.create_model_deployment_monitoring_job: gapic_v1.method_async.wrap_method( + self.create_model_deployment_monitoring_job, + default_timeout=60.0, + client_info=client_info, + ), + self.search_model_deployment_monitoring_stats_anomalies: gapic_v1.method_async.wrap_method( + self.search_model_deployment_monitoring_stats_anomalies, + default_timeout=5.0, + client_info=client_info, + ), + self.get_model_deployment_monitoring_job: gapic_v1.method_async.wrap_method( + self.get_model_deployment_monitoring_job, + default_timeout=5.0, + client_info=client_info, + ), + self.list_model_deployment_monitoring_jobs: gapic_v1.method_async.wrap_method( + self.list_model_deployment_monitoring_jobs, + default_timeout=5.0, + client_info=client_info, + ), + self.update_model_deployment_monitoring_job: gapic_v1.method_async.wrap_method( + self.update_model_deployment_monitoring_job, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_model_deployment_monitoring_job: gapic_v1.method_async.wrap_method( + self.delete_model_deployment_monitoring_job, + default_timeout=5.0, + client_info=client_info, + ), + self.pause_model_deployment_monitoring_job: 
gapic_v1.method_async.wrap_method( + self.pause_model_deployment_monitoring_job, + default_timeout=5.0, + client_info=client_info, + ), + self.resume_model_deployment_monitoring_job: gapic_v1.method_async.wrap_method( + self.resume_model_deployment_monitoring_job, + default_timeout=5.0, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/rest.py index e68b1159e1..ad8ff4813f 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/rest.py @@ -1646,10 +1646,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -2020,10 +2016,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -2382,10 +2374,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2760,10 
+2748,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -3138,10 +3122,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -7577,10 +7557,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -8008,10 +7984,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -8430,10 +8402,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -8869,10 +8837,6 @@ def __call__( "method": "get", "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -9308,10 +9272,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/llm_utility_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/llm_utility_service/async_client.py index f8f726db49..451e42770a 100644 --- a/google/cloud/aiplatform_v1beta1/services/llm_utility_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/llm_utility_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -197,7 +198,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, LlmUtilityServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + LlmUtilityServiceTransport, + Callable[..., LlmUtilityServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -209,9 +216,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.LlmUtilityServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. 
+ transport (Optional[Union[str,LlmUtilityServiceTransport,Callable[..., LlmUtilityServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the LlmUtilityServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -336,8 +345,8 @@ async def sample_compute_tokens(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances]) if request is not None and has_flattened_params: raise ValueError( @@ -345,7 +354,10 @@ async def sample_compute_tokens(): "the individual field arguments should be set." ) - request = llm_utility_service.ComputeTokensRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, llm_utility_service.ComputeTokensRequest): + request = llm_utility_service.ComputeTokensRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -356,11 +368,9 @@ async def sample_compute_tokens(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.compute_tokens, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.compute_tokens + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1beta1/services/llm_utility_service/client.py b/google/cloud/aiplatform_v1beta1/services/llm_utility_service/client.py index 8db931037c..84bbee97ee 100644 --- a/google/cloud/aiplatform_v1beta1/services/llm_utility_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/llm_utility_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -532,7 +533,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, LlmUtilityServiceTransport]] = None, + transport: Optional[ + Union[ + str, + LlmUtilityServiceTransport, + Callable[..., LlmUtilityServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -544,9 +551,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, LlmUtilityServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,LlmUtilityServiceTransport,Callable[..., LlmUtilityServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the LlmUtilityServiceTransport constructor. + If set to None, a transport is chosen automatically. 
NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -658,8 +667,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[LlmUtilityServiceTransport], + Callable[..., LlmUtilityServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., LlmUtilityServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -749,8 +766,8 @@ def sample_compute_tokens(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances]) if request is not None and has_flattened_params: raise ValueError( @@ -758,10 +775,8 @@ def sample_compute_tokens(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a llm_utility_service.ComputeTokensRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, llm_utility_service.ComputeTokensRequest): request = llm_utility_service.ComputeTokensRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1beta1/services/llm_utility_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/llm_utility_service/transports/grpc.py index 7fc996312a..af76babd19 100644 --- a/google/cloud/aiplatform_v1beta1/services/llm_utility_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/llm_utility_service/transports/grpc.py @@ -54,7 +54,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -74,14 +74,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. 
+ channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -91,11 +94,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -121,7 +124,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -162,7 +165,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/llm_utility_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/llm_utility_service/transports/grpc_asyncio.py index ae476ceafe..fcf96c4fd5 100644 --- a/google/cloud/aiplatform_v1beta1/services/llm_utility_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/llm_utility_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -69,7 +71,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -99,7 +100,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -119,15 +120,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -137,11 +141,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -167,7 +171,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -207,7 +211,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -265,6 +271,16 @@ def compute_tokens( ) return self._stubs["compute_tokens"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.compute_tokens: gapic_v1.method_async.wrap_method( + self.compute_tokens, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/llm_utility_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/llm_utility_service/transports/rest.py index 1b4290be1e..6b0fcd6867 100644 --- a/google/cloud/aiplatform_v1beta1/services/llm_utility_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/llm_utility_service/transports/rest.py @@ -1311,10 +1311,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1742,10 +1738,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -2164,10 +2156,6 @@ def __call__( "method": "get", "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2603,10 +2591,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -3042,10 +3026,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/match_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/match_service/async_client.py index 84109f4d4c..26f7ea57c7 100644 --- a/google/cloud/aiplatform_v1beta1/services/match_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/match_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -197,7 +198,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, MatchServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[str, MatchServiceTransport, Callable[..., MatchServiceTransport]] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -209,9 +212,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt 
to ascertain the credentials from the environment. - transport (Union[str, ~.MatchServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,MatchServiceTransport,Callable[..., MatchServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the MatchServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -312,15 +317,16 @@ async def sample_find_neighbors(): """ # Create or coerce a protobuf request object. - request = match_service.FindNeighborsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, match_service.FindNeighborsRequest): + request = match_service.FindNeighborsRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.find_neighbors, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.find_neighbors + ] # Certain fields should be provided within the metadata header; # add these here. @@ -399,15 +405,16 @@ async def sample_read_index_datapoints(): """ # Create or coerce a protobuf request object. - request = match_service.ReadIndexDatapointsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, match_service.ReadIndexDatapointsRequest): + request = match_service.ReadIndexDatapointsRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.read_index_datapoints, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.read_index_datapoints + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1beta1/services/match_service/client.py b/google/cloud/aiplatform_v1beta1/services/match_service/client.py index d6b97bb1c1..1966d98a0c 100644 --- a/google/cloud/aiplatform_v1beta1/services/match_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/match_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -532,7 +533,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, MatchServiceTransport]] = None, + transport: Optional[ + Union[str, MatchServiceTransport, Callable[..., MatchServiceTransport]] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -544,9 +547,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, MatchServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,MatchServiceTransport,Callable[..., MatchServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. 
+ If a Callable is given, it will be called with the same set of initialization + arguments as used in the MatchServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -655,8 +660,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[MatchServiceTransport], Callable[..., MatchServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., MatchServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -722,10 +734,8 @@ def sample_find_neighbors(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a match_service.FindNeighborsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, match_service.FindNeighborsRequest): request = match_service.FindNeighborsRequest(request) @@ -810,10 +820,8 @@ def sample_read_index_datapoints(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a match_service.ReadIndexDatapointsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, match_service.ReadIndexDatapointsRequest): request = match_service.ReadIndexDatapointsRequest(request) diff --git a/google/cloud/aiplatform_v1beta1/services/match_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/match_service/transports/grpc.py index aa46d0e49f..a07cbdc9b2 100644 --- a/google/cloud/aiplatform_v1beta1/services/match_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/match_service/transports/grpc.py @@ -55,7 +55,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -75,14 +75,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. 
If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -92,11 +95,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -122,7 +125,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -163,7 +166,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/match_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/match_service/transports/grpc_asyncio.py index 969371ce46..bc7bc5ec4a 100644 --- a/google/cloud/aiplatform_v1beta1/services/match_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/match_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -70,7 +72,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -100,7 +101,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -120,15 +121,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -138,11 +142,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -168,7 +172,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -208,7 +212,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -298,6 +304,21 @@ def read_index_datapoints( ) return self._stubs["read_index_datapoints"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.find_neighbors: gapic_v1.method_async.wrap_method( + self.find_neighbors, + default_timeout=None, + client_info=client_info, + ), + self.read_index_datapoints: gapic_v1.method_async.wrap_method( + self.read_index_datapoints, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/match_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/match_service/transports/rest.py index d749aa4613..1cb6711a52 100644 --- a/google/cloud/aiplatform_v1beta1/services/match_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/match_service/transports/rest.py @@ -1440,10 +1440,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1871,10 +1867,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": 
"/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -2293,10 +2285,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2732,10 +2720,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -3171,10 +3155,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py index 3c9f370376..2e6355ca16 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -229,7 +230,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, MetadataServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, MetadataServiceTransport, Callable[..., MetadataServiceTransport] + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo 
= DEFAULT_CLIENT_INFO, ) -> None: @@ -241,9 +246,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.MetadataServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,MetadataServiceTransport,Callable[..., MetadataServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the MetadataServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -385,8 +392,8 @@ async def sample_create_metadata_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, metadata_store, metadata_store_id]) if request is not None and has_flattened_params: raise ValueError( @@ -394,7 +401,10 @@ async def sample_create_metadata_store(): "the individual field arguments should be set." ) - request = metadata_service.CreateMetadataStoreRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.CreateMetadataStoreRequest): + request = metadata_service.CreateMetadataStoreRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -407,11 +417,9 @@ async def sample_create_metadata_store(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_metadata_store, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_metadata_store + ] # Certain fields should be provided within the metadata header; # add these here. @@ -504,8 +512,8 @@ async def sample_get_metadata_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -513,7 +521,10 @@ async def sample_get_metadata_store(): "the individual field arguments should be set." ) - request = metadata_service.GetMetadataStoreRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.GetMetadataStoreRequest): + request = metadata_service.GetMetadataStoreRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -522,11 +533,9 @@ async def sample_get_metadata_store(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_metadata_store, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_metadata_store + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -616,8 +625,8 @@ async def sample_list_metadata_stores(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -625,7 +634,10 @@ async def sample_list_metadata_stores(): "the individual field arguments should be set." ) - request = metadata_service.ListMetadataStoresRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.ListMetadataStoresRequest): + request = metadata_service.ListMetadataStoresRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -634,11 +646,9 @@ async def sample_list_metadata_stores(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_metadata_stores, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_metadata_stores + ] # Certain fields should be provided within the metadata header; # add these here. @@ -748,8 +758,8 @@ async def sample_delete_metadata_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -757,7 +767,10 @@ async def sample_delete_metadata_store(): "the individual field arguments should be set." ) - request = metadata_service.DeleteMetadataStoreRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.DeleteMetadataStoreRequest): + request = metadata_service.DeleteMetadataStoreRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -766,11 +779,9 @@ async def sample_delete_metadata_store(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_metadata_store, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_metadata_store + ] # Certain fields should be provided within the metadata header; # add these here. @@ -882,8 +893,8 @@ async def sample_create_artifact(): Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, artifact, artifact_id]) if request is not None and has_flattened_params: raise ValueError( @@ -891,7 +902,10 @@ async def sample_create_artifact(): "the individual field arguments should be set." ) - request = metadata_service.CreateArtifactRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, metadata_service.CreateArtifactRequest): + request = metadata_service.CreateArtifactRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -904,11 +918,9 @@ async def sample_create_artifact(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_artifact, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_artifact + ] # Certain fields should be provided within the metadata header; # add these here. @@ -990,8 +1002,8 @@ async def sample_get_artifact(): Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -999,7 +1011,10 @@ async def sample_get_artifact(): "the individual field arguments should be set." ) - request = metadata_service.GetArtifactRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.GetArtifactRequest): + request = metadata_service.GetArtifactRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1008,11 +1023,9 @@ async def sample_get_artifact(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_artifact, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_artifact + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1100,8 +1113,8 @@ async def sample_list_artifacts(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1109,7 +1122,10 @@ async def sample_list_artifacts(): "the individual field arguments should be set." ) - request = metadata_service.ListArtifactsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.ListArtifactsRequest): + request = metadata_service.ListArtifactsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1118,11 +1134,9 @@ async def sample_list_artifacts(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_artifacts, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_artifacts + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1223,8 +1237,8 @@ async def sample_update_artifact(): Instance of a general artifact. """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([artifact, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1232,7 +1246,10 @@ async def sample_update_artifact(): "the individual field arguments should be set." ) - request = metadata_service.UpdateArtifactRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.UpdateArtifactRequest): + request = metadata_service.UpdateArtifactRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1243,11 +1260,9 @@ async def sample_update_artifact(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_artifact, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_artifact + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1347,8 +1362,8 @@ async def sample_delete_artifact(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1356,7 +1371,10 @@ async def sample_delete_artifact(): "the individual field arguments should be set." 
) - request = metadata_service.DeleteArtifactRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.DeleteArtifactRequest): + request = metadata_service.DeleteArtifactRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1365,11 +1383,9 @@ async def sample_delete_artifact(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_artifact, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_artifact + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1468,8 +1484,8 @@ async def sample_purge_artifacts(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1477,7 +1493,10 @@ async def sample_purge_artifacts(): "the individual field arguments should be set." ) - request = metadata_service.PurgeArtifactsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.PurgeArtifactsRequest): + request = metadata_service.PurgeArtifactsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -1486,11 +1505,9 @@ async def sample_purge_artifacts(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.purge_artifacts, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.purge_artifacts + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1602,8 +1619,8 @@ async def sample_create_context(): Instance of a general context. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, context, context_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1611,7 +1628,10 @@ async def sample_create_context(): "the individual field arguments should be set." ) - request = metadata_service.CreateContextRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.CreateContextRequest): + request = metadata_service.CreateContextRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1624,11 +1644,9 @@ async def sample_create_context(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_context, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_context + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -1710,8 +1728,8 @@ async def sample_get_context(): Instance of a general context. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1719,7 +1737,10 @@ async def sample_get_context(): "the individual field arguments should be set." ) - request = metadata_service.GetContextRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.GetContextRequest): + request = metadata_service.GetContextRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1728,11 +1749,9 @@ async def sample_get_context(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_context, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_context + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1820,8 +1839,8 @@ async def sample_list_contexts(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1829,7 +1848,10 @@ async def sample_list_contexts(): "the individual field arguments should be set." ) - request = metadata_service.ListContextsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.ListContextsRequest): + request = metadata_service.ListContextsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1838,11 +1860,9 @@ async def sample_list_contexts(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_contexts, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_contexts + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1942,8 +1962,8 @@ async def sample_update_context(): Instance of a general context. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([context, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1951,7 +1971,10 @@ async def sample_update_context(): "the individual field arguments should be set." ) - request = metadata_service.UpdateContextRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, metadata_service.UpdateContextRequest): + request = metadata_service.UpdateContextRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1962,11 +1985,9 @@ async def sample_update_context(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_context, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_context + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2066,8 +2087,8 @@ async def sample_delete_context(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2075,7 +2096,10 @@ async def sample_delete_context(): "the individual field arguments should be set." ) - request = metadata_service.DeleteContextRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.DeleteContextRequest): + request = metadata_service.DeleteContextRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2084,11 +2108,9 @@ async def sample_delete_context(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_context, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_context + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2187,8 +2209,8 @@ async def sample_purge_contexts(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2196,7 +2218,10 @@ async def sample_purge_contexts(): "the individual field arguments should be set." ) - request = metadata_service.PurgeContextsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.PurgeContextsRequest): + request = metadata_service.PurgeContextsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2205,11 +2230,9 @@ async def sample_purge_contexts(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.purge_contexts, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.purge_contexts + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2327,8 +2350,8 @@ async def sample_add_context_artifacts_and_executions(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([context, artifacts, executions]) if request is not None and has_flattened_params: raise ValueError( @@ -2336,7 +2359,12 @@ async def sample_add_context_artifacts_and_executions(): "the individual field arguments should be set." ) - request = metadata_service.AddContextArtifactsAndExecutionsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, metadata_service.AddContextArtifactsAndExecutionsRequest + ): + request = metadata_service.AddContextArtifactsAndExecutionsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2349,11 +2377,9 @@ async def sample_add_context_artifacts_and_executions(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.add_context_artifacts_and_executions, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.add_context_artifacts_and_executions + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2452,8 +2478,8 @@ async def sample_add_context_children(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([context, child_contexts]) if request is not None and has_flattened_params: raise ValueError( @@ -2461,7 +2487,10 @@ async def sample_add_context_children(): "the individual field arguments should be set." ) - request = metadata_service.AddContextChildrenRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.AddContextChildrenRequest): + request = metadata_service.AddContextChildrenRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2472,11 +2501,9 @@ async def sample_add_context_children(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.add_context_children, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.add_context_children + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2573,8 +2600,8 @@ async def sample_remove_context_children(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([context, child_contexts]) if request is not None and has_flattened_params: raise ValueError( @@ -2582,7 +2609,10 @@ async def sample_remove_context_children(): "the individual field arguments should be set." ) - request = metadata_service.RemoveContextChildrenRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, metadata_service.RemoveContextChildrenRequest): + request = metadata_service.RemoveContextChildrenRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2593,11 +2623,9 @@ async def sample_remove_context_children(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.remove_context_children, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.remove_context_children + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2692,8 +2720,8 @@ async def sample_query_context_lineage_subgraph(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([context]) if request is not None and has_flattened_params: raise ValueError( @@ -2701,7 +2729,10 @@ async def sample_query_context_lineage_subgraph(): "the individual field arguments should be set." ) - request = metadata_service.QueryContextLineageSubgraphRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.QueryContextLineageSubgraphRequest): + request = metadata_service.QueryContextLineageSubgraphRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2710,11 +2741,9 @@ async def sample_query_context_lineage_subgraph(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.query_context_lineage_subgraph, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.query_context_lineage_subgraph + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2818,8 +2847,8 @@ async def sample_create_execution(): Instance of a general execution. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, execution, execution_id]) if request is not None and has_flattened_params: raise ValueError( @@ -2827,7 +2856,10 @@ async def sample_create_execution(): "the individual field arguments should be set." ) - request = metadata_service.CreateExecutionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.CreateExecutionRequest): + request = metadata_service.CreateExecutionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2840,11 +2872,9 @@ async def sample_create_execution(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_execution, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_execution + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2926,8 +2956,8 @@ async def sample_get_execution(): Instance of a general execution. 
""" # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2935,7 +2965,10 @@ async def sample_get_execution(): "the individual field arguments should be set." ) - request = metadata_service.GetExecutionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.GetExecutionRequest): + request = metadata_service.GetExecutionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2944,11 +2977,9 @@ async def sample_get_execution(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_execution, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_execution + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3036,8 +3067,8 @@ async def sample_list_executions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3045,7 +3076,10 @@ async def sample_list_executions(): "the individual field arguments should be set." 
) - request = metadata_service.ListExecutionsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.ListExecutionsRequest): + request = metadata_service.ListExecutionsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3054,11 +3088,9 @@ async def sample_list_executions(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_executions, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_executions + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3159,8 +3191,8 @@ async def sample_update_execution(): Instance of a general execution. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([execution, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -3168,7 +3200,10 @@ async def sample_update_execution(): "the individual field arguments should be set." ) - request = metadata_service.UpdateExecutionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.UpdateExecutionRequest): + request = metadata_service.UpdateExecutionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -3179,11 +3214,9 @@ async def sample_update_execution(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_execution, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_execution + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3283,8 +3316,8 @@ async def sample_delete_execution(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3292,7 +3325,10 @@ async def sample_delete_execution(): "the individual field arguments should be set." ) - request = metadata_service.DeleteExecutionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.DeleteExecutionRequest): + request = metadata_service.DeleteExecutionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3301,11 +3337,9 @@ async def sample_delete_execution(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_execution, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_execution + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -3404,8 +3438,8 @@ async def sample_purge_executions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3413,7 +3447,10 @@ async def sample_purge_executions(): "the individual field arguments should be set." ) - request = metadata_service.PurgeExecutionsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.PurgeExecutionsRequest): + request = metadata_service.PurgeExecutionsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3422,11 +3459,9 @@ async def sample_purge_executions(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.purge_executions, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.purge_executions + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3530,8 +3565,8 @@ async def sample_add_execution_events(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([execution, events]) if request is not None and has_flattened_params: raise ValueError( @@ -3539,7 +3574,10 @@ async def sample_add_execution_events(): "the individual field arguments should be set." ) - request = metadata_service.AddExecutionEventsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.AddExecutionEventsRequest): + request = metadata_service.AddExecutionEventsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3550,11 +3588,9 @@ async def sample_add_execution_events(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.add_execution_events, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.add_execution_events + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3646,8 +3682,8 @@ async def sample_query_execution_inputs_and_outputs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([execution]) if request is not None and has_flattened_params: raise ValueError( @@ -3655,7 +3691,12 @@ async def sample_query_execution_inputs_and_outputs(): "the individual field arguments should be set." ) - request = metadata_service.QueryExecutionInputsAndOutputsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance( + request, metadata_service.QueryExecutionInputsAndOutputsRequest + ): + request = metadata_service.QueryExecutionInputsAndOutputsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3664,11 +3705,9 @@ async def sample_query_execution_inputs_and_outputs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.query_execution_inputs_and_outputs, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.query_execution_inputs_and_outputs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3782,8 +3821,8 @@ async def sample_create_metadata_schema(): Instance of a general MetadataSchema. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, metadata_schema, metadata_schema_id]) if request is not None and has_flattened_params: raise ValueError( @@ -3791,7 +3830,10 @@ async def sample_create_metadata_schema(): "the individual field arguments should be set." ) - request = metadata_service.CreateMetadataSchemaRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.CreateMetadataSchemaRequest): + request = metadata_service.CreateMetadataSchemaRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -3804,11 +3846,9 @@ async def sample_create_metadata_schema(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_metadata_schema, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_metadata_schema + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3892,8 +3932,8 @@ async def sample_get_metadata_schema(): Instance of a general MetadataSchema. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3901,7 +3941,10 @@ async def sample_get_metadata_schema(): "the individual field arguments should be set." ) - request = metadata_service.GetMetadataSchemaRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.GetMetadataSchemaRequest): + request = metadata_service.GetMetadataSchemaRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3910,11 +3953,9 @@ async def sample_get_metadata_schema(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_metadata_schema, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_metadata_schema + ] # Certain fields should be provided within the metadata header; # add these here. @@ -4004,8 +4045,8 @@ async def sample_list_metadata_schemas(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -4013,7 +4054,10 @@ async def sample_list_metadata_schemas(): "the individual field arguments should be set." ) - request = metadata_service.ListMetadataSchemasRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, metadata_service.ListMetadataSchemasRequest): + request = metadata_service.ListMetadataSchemasRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -4022,11 +4066,9 @@ async def sample_list_metadata_schemas(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_metadata_schemas, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_metadata_schemas + ] # Certain fields should be provided within the metadata header; # add these here. @@ -4130,8 +4172,8 @@ async def sample_query_artifact_lineage_subgraph(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([artifact]) if request is not None and has_flattened_params: raise ValueError( @@ -4139,7 +4181,12 @@ async def sample_query_artifact_lineage_subgraph(): "the individual field arguments should be set." ) - request = metadata_service.QueryArtifactLineageSubgraphRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, metadata_service.QueryArtifactLineageSubgraphRequest + ): + request = metadata_service.QueryArtifactLineageSubgraphRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -4148,11 +4195,9 @@ async def sample_query_artifact_lineage_subgraph(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.query_artifact_lineage_subgraph, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.query_artifact_lineage_subgraph + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py index a5975366ff..cea276b1ae 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -648,7 +649,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, MetadataServiceTransport]] = None, + transport: Optional[ + Union[ + str, MetadataServiceTransport, Callable[..., MetadataServiceTransport] + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -660,9 +665,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, MetadataServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,MetadataServiceTransport,Callable[..., MetadataServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the MetadataServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -774,8 +781,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[MetadataServiceTransport], Callable[..., MetadataServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., MetadataServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -882,8 +896,8 @@ def sample_create_metadata_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, metadata_store, metadata_store_id]) if request is not None and has_flattened_params: raise ValueError( @@ -891,10 +905,8 @@ def sample_create_metadata_store(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.CreateMetadataStoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.CreateMetadataStoreRequest): request = metadata_service.CreateMetadataStoreRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1001,8 +1013,8 @@ def sample_get_metadata_store(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1010,10 +1022,8 @@ def sample_get_metadata_store(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.GetMetadataStoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.GetMetadataStoreRequest): request = metadata_service.GetMetadataStoreRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1113,8 +1123,8 @@ def sample_list_metadata_stores(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1122,10 +1132,8 @@ def sample_list_metadata_stores(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.ListMetadataStoresRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, metadata_service.ListMetadataStoresRequest): request = metadata_service.ListMetadataStoresRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1245,8 +1253,8 @@ def sample_delete_metadata_store(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1254,10 +1262,8 @@ def sample_delete_metadata_store(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.DeleteMetadataStoreRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.DeleteMetadataStoreRequest): request = metadata_service.DeleteMetadataStoreRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1379,8 +1385,8 @@ def sample_create_artifact(): Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, artifact, artifact_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1388,10 +1394,8 @@ def sample_create_artifact(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.CreateArtifactRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.CreateArtifactRequest): request = metadata_service.CreateArtifactRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1487,8 +1491,8 @@ def sample_get_artifact(): Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1496,10 +1500,8 @@ def sample_get_artifact(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.GetArtifactRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.GetArtifactRequest): request = metadata_service.GetArtifactRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1597,8 +1599,8 @@ def sample_list_artifacts(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1606,10 +1608,8 @@ def sample_list_artifacts(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.ListArtifactsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.ListArtifactsRequest): request = metadata_service.ListArtifactsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1720,8 +1720,8 @@ def sample_update_artifact(): Instance of a general artifact. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([artifact, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1729,10 +1729,8 @@ def sample_update_artifact(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.UpdateArtifactRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, metadata_service.UpdateArtifactRequest): request = metadata_service.UpdateArtifactRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1844,8 +1842,8 @@ def sample_delete_artifact(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1853,10 +1851,8 @@ def sample_delete_artifact(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.DeleteArtifactRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.DeleteArtifactRequest): request = metadata_service.DeleteArtifactRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1965,8 +1961,8 @@ def sample_purge_artifacts(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1974,10 +1970,8 @@ def sample_purge_artifacts(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.PurgeArtifactsRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.PurgeArtifactsRequest): request = metadata_service.PurgeArtifactsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2099,8 +2093,8 @@ def sample_create_context(): Instance of a general context. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, context, context_id]) if request is not None and has_flattened_params: raise ValueError( @@ -2108,10 +2102,8 @@ def sample_create_context(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.CreateContextRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.CreateContextRequest): request = metadata_service.CreateContextRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2207,8 +2199,8 @@ def sample_get_context(): Instance of a general context. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2216,10 +2208,8 @@ def sample_get_context(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.GetContextRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.GetContextRequest): request = metadata_service.GetContextRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2317,8 +2307,8 @@ def sample_list_contexts(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2326,10 +2316,8 @@ def sample_list_contexts(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.ListContextsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.ListContextsRequest): request = metadata_service.ListContextsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2439,8 +2427,8 @@ def sample_update_context(): Instance of a general context. """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([context, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -2448,10 +2436,8 @@ def sample_update_context(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.UpdateContextRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.UpdateContextRequest): request = metadata_service.UpdateContextRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2563,8 +2549,8 @@ def sample_delete_context(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2572,10 +2558,8 @@ def sample_delete_context(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.DeleteContextRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, metadata_service.DeleteContextRequest): request = metadata_service.DeleteContextRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2684,8 +2668,8 @@ def sample_purge_contexts(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2693,10 +2677,8 @@ def sample_purge_contexts(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.PurgeContextsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.PurgeContextsRequest): request = metadata_service.PurgeContextsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2824,8 +2806,8 @@ def sample_add_context_artifacts_and_executions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([context, artifacts, executions]) if request is not None and has_flattened_params: raise ValueError( @@ -2833,10 +2815,8 @@ def sample_add_context_artifacts_and_executions(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.AddContextArtifactsAndExecutionsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, metadata_service.AddContextArtifactsAndExecutionsRequest ): @@ -2953,8 +2933,8 @@ def sample_add_context_children(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([context, child_contexts]) if request is not None and has_flattened_params: raise ValueError( @@ -2962,10 +2942,8 @@ def sample_add_context_children(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.AddContextChildrenRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.AddContextChildrenRequest): request = metadata_service.AddContextChildrenRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3074,8 +3052,8 @@ def sample_remove_context_children(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([context, child_contexts]) if request is not None and has_flattened_params: raise ValueError( @@ -3083,10 +3061,8 @@ def sample_remove_context_children(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.RemoveContextChildrenRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.RemoveContextChildrenRequest): request = metadata_service.RemoveContextChildrenRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3193,8 +3169,8 @@ def sample_query_context_lineage_subgraph(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([context]) if request is not None and has_flattened_params: raise ValueError( @@ -3202,10 +3178,8 @@ def sample_query_context_lineage_subgraph(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.QueryContextLineageSubgraphRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, metadata_service.QueryContextLineageSubgraphRequest): request = metadata_service.QueryContextLineageSubgraphRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3321,8 +3295,8 @@ def sample_create_execution(): Instance of a general execution. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, execution, execution_id]) if request is not None and has_flattened_params: raise ValueError( @@ -3330,10 +3304,8 @@ def sample_create_execution(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.CreateExecutionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.CreateExecutionRequest): request = metadata_service.CreateExecutionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3429,8 +3401,8 @@ def sample_get_execution(): Instance of a general execution. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3438,10 +3410,8 @@ def sample_get_execution(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.GetExecutionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.GetExecutionRequest): request = metadata_service.GetExecutionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3539,8 +3509,8 @@ def sample_list_executions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3548,10 +3518,8 @@ def sample_list_executions(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.ListExecutionsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.ListExecutionsRequest): request = metadata_service.ListExecutionsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3662,8 +3630,8 @@ def sample_update_execution(): Instance of a general execution. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([execution, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -3671,10 +3639,8 @@ def sample_update_execution(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.UpdateExecutionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.UpdateExecutionRequest): request = metadata_service.UpdateExecutionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3786,8 +3752,8 @@ def sample_delete_execution(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3795,10 +3761,8 @@ def sample_delete_execution(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.DeleteExecutionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, metadata_service.DeleteExecutionRequest): request = metadata_service.DeleteExecutionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3907,8 +3871,8 @@ def sample_purge_executions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3916,10 +3880,8 @@ def sample_purge_executions(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.PurgeExecutionsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.PurgeExecutionsRequest): request = metadata_service.PurgeExecutionsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4033,8 +3995,8 @@ def sample_add_execution_events(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([execution, events]) if request is not None and has_flattened_params: raise ValueError( @@ -4042,10 +4004,8 @@ def sample_add_execution_events(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.AddExecutionEventsRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.AddExecutionEventsRequest): request = metadata_service.AddExecutionEventsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4149,8 +4109,8 @@ def sample_query_execution_inputs_and_outputs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([execution]) if request is not None and has_flattened_params: raise ValueError( @@ -4158,10 +4118,8 @@ def sample_query_execution_inputs_and_outputs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.QueryExecutionInputsAndOutputsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, metadata_service.QueryExecutionInputsAndOutputsRequest ): @@ -4289,8 +4247,8 @@ def sample_create_metadata_schema(): Instance of a general MetadataSchema. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, metadata_schema, metadata_schema_id]) if request is not None and has_flattened_params: raise ValueError( @@ -4298,10 +4256,8 @@ def sample_create_metadata_schema(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.CreateMetadataSchemaRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.CreateMetadataSchemaRequest): request = metadata_service.CreateMetadataSchemaRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4399,8 +4355,8 @@ def sample_get_metadata_schema(): Instance of a general MetadataSchema. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -4408,10 +4364,8 @@ def sample_get_metadata_schema(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.GetMetadataSchemaRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, metadata_service.GetMetadataSchemaRequest): request = metadata_service.GetMetadataSchemaRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4511,8 +4465,8 @@ def sample_list_metadata_schemas(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -4520,10 +4474,8 @@ def sample_list_metadata_schemas(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.ListMetadataSchemasRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, metadata_service.ListMetadataSchemasRequest): request = metadata_service.ListMetadataSchemasRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4637,8 +4589,8 @@ def sample_query_artifact_lineage_subgraph(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([artifact]) if request is not None and has_flattened_params: raise ValueError( @@ -4646,10 +4598,8 @@ def sample_query_artifact_lineage_subgraph(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a metadata_service.QueryArtifactLineageSubgraphRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, metadata_service.QueryArtifactLineageSubgraphRequest ): diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py index dc3095176b..b8087a632c 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py @@ -65,7 +65,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -85,14 +85,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. 
- channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -102,11 +105,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -133,7 +136,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -174,7 +177,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py index c76b9eec91..758a1c99ca 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -80,7 +82,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -110,7 +111,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -130,15 +131,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -148,11 +152,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -179,7 +183,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -219,7 +223,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -1204,6 +1210,171 @@ def query_artifact_lineage_subgraph( ) return self._stubs["query_artifact_lineage_subgraph"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_metadata_store: gapic_v1.method_async.wrap_method( + self.create_metadata_store, + default_timeout=5.0, + client_info=client_info, + ), + self.get_metadata_store: gapic_v1.method_async.wrap_method( + self.get_metadata_store, + default_timeout=5.0, + client_info=client_info, + ), + self.list_metadata_stores: gapic_v1.method_async.wrap_method( + self.list_metadata_stores, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_metadata_store: gapic_v1.method_async.wrap_method( + self.delete_metadata_store, + default_timeout=5.0, + client_info=client_info, + ), + self.create_artifact: gapic_v1.method_async.wrap_method( + self.create_artifact, + default_timeout=5.0, + client_info=client_info, + ), + self.get_artifact: gapic_v1.method_async.wrap_method( + self.get_artifact, + default_timeout=5.0, + client_info=client_info, + ), + self.list_artifacts: gapic_v1.method_async.wrap_method( + self.list_artifacts, + default_timeout=5.0, + client_info=client_info, + ), + self.update_artifact: gapic_v1.method_async.wrap_method( + self.update_artifact, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_artifact: gapic_v1.method_async.wrap_method( + self.delete_artifact, + default_timeout=None, + client_info=client_info, + ), + self.purge_artifacts: gapic_v1.method_async.wrap_method( + self.purge_artifacts, + 
default_timeout=None, + client_info=client_info, + ), + self.create_context: gapic_v1.method_async.wrap_method( + self.create_context, + default_timeout=5.0, + client_info=client_info, + ), + self.get_context: gapic_v1.method_async.wrap_method( + self.get_context, + default_timeout=5.0, + client_info=client_info, + ), + self.list_contexts: gapic_v1.method_async.wrap_method( + self.list_contexts, + default_timeout=5.0, + client_info=client_info, + ), + self.update_context: gapic_v1.method_async.wrap_method( + self.update_context, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_context: gapic_v1.method_async.wrap_method( + self.delete_context, + default_timeout=5.0, + client_info=client_info, + ), + self.purge_contexts: gapic_v1.method_async.wrap_method( + self.purge_contexts, + default_timeout=None, + client_info=client_info, + ), + self.add_context_artifacts_and_executions: gapic_v1.method_async.wrap_method( + self.add_context_artifacts_and_executions, + default_timeout=5.0, + client_info=client_info, + ), + self.add_context_children: gapic_v1.method_async.wrap_method( + self.add_context_children, + default_timeout=5.0, + client_info=client_info, + ), + self.remove_context_children: gapic_v1.method_async.wrap_method( + self.remove_context_children, + default_timeout=None, + client_info=client_info, + ), + self.query_context_lineage_subgraph: gapic_v1.method_async.wrap_method( + self.query_context_lineage_subgraph, + default_timeout=5.0, + client_info=client_info, + ), + self.create_execution: gapic_v1.method_async.wrap_method( + self.create_execution, + default_timeout=5.0, + client_info=client_info, + ), + self.get_execution: gapic_v1.method_async.wrap_method( + self.get_execution, + default_timeout=5.0, + client_info=client_info, + ), + self.list_executions: gapic_v1.method_async.wrap_method( + self.list_executions, + default_timeout=5.0, + client_info=client_info, + ), + self.update_execution: gapic_v1.method_async.wrap_method( + 
self.update_execution, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_execution: gapic_v1.method_async.wrap_method( + self.delete_execution, + default_timeout=None, + client_info=client_info, + ), + self.purge_executions: gapic_v1.method_async.wrap_method( + self.purge_executions, + default_timeout=None, + client_info=client_info, + ), + self.add_execution_events: gapic_v1.method_async.wrap_method( + self.add_execution_events, + default_timeout=5.0, + client_info=client_info, + ), + self.query_execution_inputs_and_outputs: gapic_v1.method_async.wrap_method( + self.query_execution_inputs_and_outputs, + default_timeout=5.0, + client_info=client_info, + ), + self.create_metadata_schema: gapic_v1.method_async.wrap_method( + self.create_metadata_schema, + default_timeout=5.0, + client_info=client_info, + ), + self.get_metadata_schema: gapic_v1.method_async.wrap_method( + self.get_metadata_schema, + default_timeout=5.0, + client_info=client_info, + ), + self.list_metadata_schemas: gapic_v1.method_async.wrap_method( + self.list_metadata_schemas, + default_timeout=5.0, + client_info=client_info, + ), + self.query_artifact_lineage_subgraph: gapic_v1.method_async.wrap_method( + self.query_artifact_lineage_subgraph, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/rest.py index c0d43fde79..53414eed71 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/rest.py @@ -1630,10 +1630,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -2004,10 +2000,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -2366,10 +2358,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2744,10 +2732,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -3122,10 +3106,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -7235,10 +7215,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - 
"method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -7666,10 +7642,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -8088,10 +8060,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -8527,10 +8495,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -8966,10 +8930,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py index 4de47c4dc5..160c156a21 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py +++ 
b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -216,7 +217,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, MigrationServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, MigrationServiceTransport, Callable[..., MigrationServiceTransport] + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -228,9 +233,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.MigrationServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,MigrationServiceTransport,Callable[..., MigrationServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the MigrationServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -350,8 +357,8 @@ async def sample_search_migratable_resources(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -359,7 +366,10 @@ async def sample_search_migratable_resources(): "the individual field arguments should be set." ) - request = migration_service.SearchMigratableResourcesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, migration_service.SearchMigratableResourcesRequest): + request = migration_service.SearchMigratableResourcesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -368,11 +378,9 @@ async def sample_search_migratable_resources(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.search_migratable_resources, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.search_migratable_resources + ] # Certain fields should be provided within the metadata header; # add these here. @@ -494,8 +502,8 @@ async def sample_batch_migrate_resources(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, migrate_resource_requests]) if request is not None and has_flattened_params: raise ValueError( @@ -503,7 +511,10 @@ async def sample_batch_migrate_resources(): "the individual field arguments should be set." ) - request = migration_service.BatchMigrateResourcesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, migration_service.BatchMigrateResourcesRequest): + request = migration_service.BatchMigrateResourcesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -514,11 +525,9 @@ async def sample_batch_migrate_resources(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.batch_migrate_resources, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_migrate_resources + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py index 299c8b5e65..b54b12bcba 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -215,40 +216,40 @@ def parse_annotated_dataset_path(path: str) -> Dict[str, str]: @staticmethod def dataset_path( project: str, + location: str, dataset: str, ) -> str: """Returns a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format( + return "projects/{project}/locations/{location}/datasets/{dataset}".format( project=project, + location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parses a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod def dataset_path( project: str, - location: str, dataset: str, ) -> str: """Returns 
a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( + return "projects/{project}/datasets/{dataset}".format( project=project, - location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parses a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod @@ -664,7 +665,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, MigrationServiceTransport]] = None, + transport: Optional[ + Union[ + str, MigrationServiceTransport, Callable[..., MigrationServiceTransport] + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -676,9 +681,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, MigrationServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,MigrationServiceTransport,Callable[..., MigrationServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the MigrationServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -790,8 +797,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[MigrationServiceTransport], + Callable[..., MigrationServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., MigrationServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -876,8 +891,8 @@ def sample_search_migratable_resources(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -885,10 +900,8 @@ def sample_search_migratable_resources(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a migration_service.SearchMigratableResourcesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, migration_service.SearchMigratableResourcesRequest): request = migration_service.SearchMigratableResourcesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1022,8 +1035,8 @@ def sample_batch_migrate_resources(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, migrate_resource_requests]) if request is not None and has_flattened_params: raise ValueError( @@ -1031,10 +1044,8 @@ def sample_batch_migrate_resources(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a migration_service.BatchMigrateResourcesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, migration_service.BatchMigrateResourcesRequest): request = migration_service.BatchMigrateResourcesRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py index 3fe98ccdd3..22b894d9ff 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py @@ -56,7 +56,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -76,14 +76,17 @@ def __init__( credentials identify the application to the service; if none 
are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -93,11 +96,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. 
quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -124,7 +127,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. @@ -165,7 +168,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py index 54c0b76ddc..ca009de5aa 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -71,7 +73,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. 
These are only used when credentials are not specified and are passed to :func:`google.auth.default`. @@ -101,7 +102,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -121,15 +122,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. 
If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -139,11 +143,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -170,7 +174,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -210,7 +214,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -318,6 +324,21 @@ def batch_migrate_resources( ) return self._stubs["batch_migrate_resources"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.search_migratable_resources: gapic_v1.method_async.wrap_method( + self.search_migratable_resources, + default_timeout=None, + client_info=client_info, + ), + self.batch_migrate_resources: gapic_v1.method_async.wrap_method( + self.batch_migrate_resources, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/rest.py index 73f43ffadb..edbcfea96e 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/rest.py @@ -693,10 +693,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1067,10 +1063,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - 
"method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -1429,10 +1421,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -1807,10 +1795,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -2185,10 +2169,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -3325,10 +3305,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -3756,10 +3732,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -4178,10 +4150,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -4617,10 +4585,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -5056,10 +5020,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/model_garden_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/model_garden_service/async_client.py index 2c84cc5a80..f31fea297f 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_garden_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_garden_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -202,7 +203,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, ModelGardenServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + ModelGardenServiceTransport, + Callable[..., 
ModelGardenServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -214,9 +221,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.ModelGardenServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,ModelGardenServiceTransport,Callable[..., ModelGardenServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ModelGardenServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -325,8 +334,8 @@ async def sample_get_publisher_model(): A Model Garden Publisher Model. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -334,7 +343,10 @@ async def sample_get_publisher_model(): "the individual field arguments should be set." ) - request = model_garden_service.GetPublisherModelRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, model_garden_service.GetPublisherModelRequest): + request = model_garden_service.GetPublisherModelRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -343,11 +355,9 @@ async def sample_get_publisher_model(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_publisher_model, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_publisher_model + ] # Certain fields should be provided within the metadata header; # add these here. @@ -436,8 +446,8 @@ async def sample_list_publisher_models(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -445,7 +455,10 @@ async def sample_list_publisher_models(): "the individual field arguments should be set." ) - request = model_garden_service.ListPublisherModelsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_garden_service.ListPublisherModelsRequest): + request = model_garden_service.ListPublisherModelsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -454,11 +467,9 @@ async def sample_list_publisher_models(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_publisher_models, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_publisher_models + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1beta1/services/model_garden_service/client.py b/google/cloud/aiplatform_v1beta1/services/model_garden_service/client.py index 225644210e..4c5ca703d4 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_garden_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_garden_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -529,7 +530,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, ModelGardenServiceTransport]] = None, + transport: Optional[ + Union[ + str, + ModelGardenServiceTransport, + Callable[..., ModelGardenServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -541,9 +548,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ModelGardenServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,ModelGardenServiceTransport,Callable[..., ModelGardenServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ModelGardenServiceTransport constructor. + If set to None, a transport is chosen automatically. 
NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -655,8 +664,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[ModelGardenServiceTransport], + Callable[..., ModelGardenServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., ModelGardenServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -730,8 +747,8 @@ def sample_get_publisher_model(): A Model Garden Publisher Model. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -739,10 +756,8 @@ def sample_get_publisher_model(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_garden_service.GetPublisherModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, model_garden_service.GetPublisherModelRequest): request = model_garden_service.GetPublisherModelRequest(request) # If we have keyword arguments corresponding to fields on the @@ -841,8 +856,8 @@ def sample_list_publisher_models(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -850,10 +865,8 @@ def sample_list_publisher_models(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_garden_service.ListPublisherModelsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, model_garden_service.ListPublisherModelsRequest): request = model_garden_service.ListPublisherModelsRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/grpc.py index 285de87fea..45ac378c31 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/grpc.py @@ -55,7 +55,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -75,14 +75,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. 
+ channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -92,11 +95,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -122,7 +125,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -163,7 +166,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/grpc_asyncio.py index e4840500b8..07bfaa0492 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -70,7 +72,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -100,7 +101,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -120,15 +121,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -138,11 +142,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -168,7 +172,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -208,7 +212,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -295,6 +301,21 @@ def list_publisher_models( ) return self._stubs["list_publisher_models"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.get_publisher_model: gapic_v1.method_async.wrap_method( + self.get_publisher_model, + default_timeout=None, + client_info=client_info, + ), + self.list_publisher_models: gapic_v1.method_async.wrap_method( + self.list_publisher_models, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/rest.py index a7125e9620..cf0dacfe5f 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/model_garden_service/transports/rest.py @@ -1433,10 +1433,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1864,10 +1860,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": 
"delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -2286,10 +2278,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2725,10 +2713,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -3164,10 +3148,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/async_client.py index 2beca43a6c..dbbdfd96b6 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -242,7 +243,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, ModelMonitoringServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + ModelMonitoringServiceTransport, + Callable[..., ModelMonitoringServiceTransport], + ] + ] = "grpc_asyncio", client_options: 
Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -254,9 +261,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.ModelMonitoringServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,ModelMonitoringServiceTransport,Callable[..., ModelMonitoringServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ModelMonitoringServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -381,8 +390,8 @@ async def sample_create_model_monitor(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_monitor]) if request is not None and has_flattened_params: raise ValueError( @@ -390,7 +399,10 @@ async def sample_create_model_monitor(): "the individual field arguments should be set." ) - request = model_monitoring_service.CreateModelMonitorRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, model_monitoring_service.CreateModelMonitorRequest): + request = model_monitoring_service.CreateModelMonitorRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -401,11 +413,9 @@ async def sample_create_model_monitor(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_model_monitor, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_model_monitor + ] # Certain fields should be provided within the metadata header; # add these here. @@ -514,8 +524,8 @@ async def sample_update_model_monitor(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model_monitor, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -523,7 +533,10 @@ async def sample_update_model_monitor(): "the individual field arguments should be set." ) - request = model_monitoring_service.UpdateModelMonitorRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_monitoring_service.UpdateModelMonitorRequest): + request = model_monitoring_service.UpdateModelMonitorRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -534,11 +547,9 @@ async def sample_update_model_monitor(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_model_monitor, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_model_monitor + ] # Certain fields should be provided within the metadata header; # add these here. @@ -638,8 +649,8 @@ async def sample_get_model_monitor(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -647,7 +658,10 @@ async def sample_get_model_monitor(): "the individual field arguments should be set." ) - request = model_monitoring_service.GetModelMonitorRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_monitoring_service.GetModelMonitorRequest): + request = model_monitoring_service.GetModelMonitorRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -656,11 +670,9 @@ async def sample_get_model_monitor(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_model_monitor, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_model_monitor + ] # Certain fields should be provided within the metadata header; # add these here. @@ -750,8 +762,8 @@ async def sample_list_model_monitors(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -759,7 +771,10 @@ async def sample_list_model_monitors(): "the individual field arguments should be set." ) - request = model_monitoring_service.ListModelMonitorsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_monitoring_service.ListModelMonitorsRequest): + request = model_monitoring_service.ListModelMonitorsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -768,11 +783,9 @@ async def sample_list_model_monitors(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_model_monitors, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_model_monitors + ] # Certain fields should be provided within the metadata header; # add these here. @@ -881,8 +894,8 @@ async def sample_delete_model_monitor(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -890,7 +903,10 @@ async def sample_delete_model_monitor(): "the individual field arguments should be set." 
) - request = model_monitoring_service.DeleteModelMonitorRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_monitoring_service.DeleteModelMonitorRequest): + request = model_monitoring_service.DeleteModelMonitorRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -899,11 +915,9 @@ async def sample_delete_model_monitor(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_model_monitor, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_model_monitor + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1007,8 +1021,8 @@ async def sample_create_model_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_monitoring_job]) if request is not None and has_flattened_params: raise ValueError( @@ -1016,7 +1030,12 @@ async def sample_create_model_monitoring_job(): "the individual field arguments should be set." ) - request = model_monitoring_service.CreateModelMonitoringJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance( + request, model_monitoring_service.CreateModelMonitoringJobRequest + ): + request = model_monitoring_service.CreateModelMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1027,11 +1046,9 @@ async def sample_create_model_monitoring_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_model_monitoring_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_model_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1118,8 +1135,8 @@ async def sample_get_model_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1127,7 +1144,12 @@ async def sample_get_model_monitoring_job(): "the individual field arguments should be set." ) - request = model_monitoring_service.GetModelMonitoringJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, model_monitoring_service.GetModelMonitoringJobRequest + ): + request = model_monitoring_service.GetModelMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1136,11 +1158,9 @@ async def sample_get_model_monitoring_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_model_monitoring_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_model_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1234,8 +1254,8 @@ async def sample_list_model_monitoring_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1243,7 +1263,12 @@ async def sample_list_model_monitoring_jobs(): "the individual field arguments should be set." ) - request = model_monitoring_service.ListModelMonitoringJobsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, model_monitoring_service.ListModelMonitoringJobsRequest + ): + request = model_monitoring_service.ListModelMonitoringJobsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1252,11 +1277,9 @@ async def sample_list_model_monitoring_jobs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_model_monitoring_jobs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_model_monitoring_jobs + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -1365,8 +1388,8 @@ async def sample_delete_model_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1374,7 +1397,12 @@ async def sample_delete_model_monitoring_job(): "the individual field arguments should be set." ) - request = model_monitoring_service.DeleteModelMonitoringJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, model_monitoring_service.DeleteModelMonitoringJobRequest + ): + request = model_monitoring_service.DeleteModelMonitoringJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1383,11 +1411,9 @@ async def sample_delete_model_monitoring_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_model_monitoring_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_model_monitoring_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1485,8 +1511,8 @@ async def sample_search_model_monitoring_stats(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([model_monitor]) if request is not None and has_flattened_params: raise ValueError( @@ -1494,7 +1520,14 @@ async def sample_search_model_monitoring_stats(): "the individual field arguments should be set." ) - request = model_monitoring_service.SearchModelMonitoringStatsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, model_monitoring_service.SearchModelMonitoringStatsRequest + ): + request = model_monitoring_service.SearchModelMonitoringStatsRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1503,11 +1536,9 @@ async def sample_search_model_monitoring_stats(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.search_model_monitoring_stats, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.search_model_monitoring_stats + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1607,8 +1638,8 @@ async def sample_search_model_monitoring_alerts(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model_monitor]) if request is not None and has_flattened_params: raise ValueError( @@ -1616,7 +1647,14 @@ async def sample_search_model_monitoring_alerts(): "the individual field arguments should be set." 
) - request = model_monitoring_service.SearchModelMonitoringAlertsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, model_monitoring_service.SearchModelMonitoringAlertsRequest + ): + request = model_monitoring_service.SearchModelMonitoringAlertsRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1625,11 +1663,9 @@ async def sample_search_model_monitoring_alerts(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.search_model_monitoring_alerts, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.search_model_monitoring_alerts + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/client.py b/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/client.py index 509fca20d8..d099ab2711 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -688,7 +689,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, ModelMonitoringServiceTransport]] = None, + transport: Optional[ + Union[ + str, + ModelMonitoringServiceTransport, + Callable[..., ModelMonitoringServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -700,9 +707,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ModelMonitoringServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,ModelMonitoringServiceTransport,Callable[..., ModelMonitoringServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ModelMonitoringServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -814,8 +823,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[ModelMonitoringServiceTransport], + Callable[..., ModelMonitoringServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., ModelMonitoringServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -905,8 +922,8 @@ def sample_create_model_monitor(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_monitor]) if request is not None and has_flattened_params: raise ValueError( @@ -914,10 +931,8 @@ def sample_create_model_monitor(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_monitoring_service.CreateModelMonitorRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_monitoring_service.CreateModelMonitorRequest): request = model_monitoring_service.CreateModelMonitorRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1038,8 +1053,8 @@ def sample_update_model_monitor(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model_monitor, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1047,10 +1062,8 @@ def sample_update_model_monitor(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_monitoring_service.UpdateModelMonitorRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_monitoring_service.UpdateModelMonitorRequest): request = model_monitoring_service.UpdateModelMonitorRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1162,8 +1175,8 @@ def sample_get_model_monitor(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1171,10 +1184,8 @@ def sample_get_model_monitor(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_monitoring_service.GetModelMonitorRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, model_monitoring_service.GetModelMonitorRequest): request = model_monitoring_service.GetModelMonitorRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1274,8 +1285,8 @@ def sample_list_model_monitors(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1283,10 +1294,8 @@ def sample_list_model_monitors(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_monitoring_service.ListModelMonitorsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_monitoring_service.ListModelMonitorsRequest): request = model_monitoring_service.ListModelMonitorsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1405,8 +1414,8 @@ def sample_delete_model_monitor(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1414,10 +1423,8 @@ def sample_delete_model_monitor(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a model_monitoring_service.DeleteModelMonitorRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_monitoring_service.DeleteModelMonitorRequest): request = model_monitoring_service.DeleteModelMonitorRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1531,8 +1538,8 @@ def sample_create_model_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_monitoring_job]) if request is not None and has_flattened_params: raise ValueError( @@ -1540,10 +1547,8 @@ def sample_create_model_monitoring_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_monitoring_service.CreateModelMonitoringJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, model_monitoring_service.CreateModelMonitoringJobRequest ): @@ -1646,8 +1651,8 @@ def sample_get_model_monitoring_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1655,10 +1660,8 @@ def sample_get_model_monitoring_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_monitoring_service.GetModelMonitoringJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, model_monitoring_service.GetModelMonitoringJobRequest ): @@ -1764,8 +1767,8 @@ def sample_list_model_monitoring_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1773,10 +1776,8 @@ def sample_list_model_monitoring_jobs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_monitoring_service.ListModelMonitoringJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, model_monitoring_service.ListModelMonitoringJobsRequest ): @@ -1899,8 +1900,8 @@ def sample_delete_model_monitoring_job(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1908,10 +1909,8 @@ def sample_delete_model_monitoring_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_monitoring_service.DeleteModelMonitoringJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, model_monitoring_service.DeleteModelMonitoringJobRequest ): @@ -2023,8 +2022,8 @@ def sample_search_model_monitoring_stats(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model_monitor]) if request is not None and has_flattened_params: raise ValueError( @@ -2032,10 +2031,8 @@ def sample_search_model_monitoring_stats(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_monitoring_service.SearchModelMonitoringStatsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, model_monitoring_service.SearchModelMonitoringStatsRequest ): @@ -2151,8 +2148,8 @@ def sample_search_model_monitoring_alerts(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model_monitor]) if request is not None and has_flattened_params: raise ValueError( @@ -2160,10 +2157,8 @@ def sample_search_model_monitoring_alerts(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_monitoring_service.SearchModelMonitoringAlertsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, model_monitoring_service.SearchModelMonitoringAlertsRequest ): diff --git a/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/transports/grpc.py index 6340db5548..3cd7493f5f 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/transports/grpc.py @@ -62,7 +62,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -82,14 +82,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. 
If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -99,11 +102,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -130,7 +133,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -171,7 +174,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/transports/grpc_asyncio.py index ceeb91f21c..413052ae38 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -77,7 +79,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -107,7 +108,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -127,15 +128,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -145,11 +149,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -176,7 +180,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -216,7 +220,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -590,6 +596,66 @@ def search_model_monitoring_alerts( ) return self._stubs["search_model_monitoring_alerts"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_model_monitor: gapic_v1.method_async.wrap_method( + self.create_model_monitor, + default_timeout=None, + client_info=client_info, + ), + self.update_model_monitor: gapic_v1.method_async.wrap_method( + self.update_model_monitor, + default_timeout=None, + client_info=client_info, + ), + self.get_model_monitor: gapic_v1.method_async.wrap_method( + self.get_model_monitor, + default_timeout=None, + client_info=client_info, + ), + self.list_model_monitors: gapic_v1.method_async.wrap_method( + self.list_model_monitors, + default_timeout=None, + client_info=client_info, + ), + self.delete_model_monitor: gapic_v1.method_async.wrap_method( + self.delete_model_monitor, + default_timeout=None, + client_info=client_info, + ), + self.create_model_monitoring_job: gapic_v1.method_async.wrap_method( + self.create_model_monitoring_job, + default_timeout=None, + client_info=client_info, + ), + self.get_model_monitoring_job: gapic_v1.method_async.wrap_method( + self.get_model_monitoring_job, + default_timeout=None, + client_info=client_info, + ), + self.list_model_monitoring_jobs: gapic_v1.method_async.wrap_method( + self.list_model_monitoring_jobs, + default_timeout=None, + client_info=client_info, + ), + self.delete_model_monitoring_job: gapic_v1.method_async.wrap_method( + self.delete_model_monitoring_job, + default_timeout=None, + 
client_info=client_info, + ), + self.search_model_monitoring_stats: gapic_v1.method_async.wrap_method( + self.search_model_monitoring_stats, + default_timeout=None, + client_info=client_info, + ), + self.search_model_monitoring_alerts: gapic_v1.method_async.wrap_method( + self.search_model_monitoring_alerts, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/transports/rest.py index 1b03c8763b..d6f96a8157 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/transports/rest.py @@ -1001,10 +1001,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1375,10 +1371,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -1737,10 +1729,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": 
"get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2115,10 +2103,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -2493,10 +2477,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -4568,10 +4548,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -4999,10 +4975,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -5421,10 +5393,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ 
-5860,10 +5828,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -6299,10 +6263,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py index cdffd81bc4..3818ddb515 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -227,7 +228,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, ModelServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[str, ModelServiceTransport, Callable[..., ModelServiceTransport]] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -239,9 +242,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.ModelServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. 
+ transport (Optional[Union[str,ModelServiceTransport,Callable[..., ModelServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ModelServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -367,8 +372,8 @@ async def sample_upload_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model]) if request is not None and has_flattened_params: raise ValueError( @@ -376,7 +381,10 @@ async def sample_upload_model(): "the individual field arguments should be set." ) - request = model_service.UploadModelRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.UploadModelRequest): + request = model_service.UploadModelRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -387,11 +395,9 @@ async def sample_upload_model(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.upload_model, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.upload_model + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -491,8 +497,8 @@ async def sample_get_model(): A trained machine learning Model. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -500,7 +506,10 @@ async def sample_get_model(): "the individual field arguments should be set." ) - request = model_service.GetModelRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.GetModelRequest): + request = model_service.GetModelRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -509,11 +518,9 @@ async def sample_get_model(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_model, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_model + ] # Certain fields should be provided within the metadata header; # add these here. @@ -601,8 +608,8 @@ async def sample_list_models(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -610,7 +617,10 @@ async def sample_list_models(): "the individual field arguments should be set." 
) - request = model_service.ListModelsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.ListModelsRequest): + request = model_service.ListModelsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -619,11 +629,9 @@ async def sample_list_models(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_models, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_models + ] # Certain fields should be provided within the metadata header; # add these here. @@ -719,8 +727,8 @@ async def sample_list_model_versions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -728,7 +736,10 @@ async def sample_list_model_versions(): "the individual field arguments should be set." ) - request = model_service.ListModelVersionsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.ListModelVersionsRequest): + request = model_service.ListModelVersionsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -737,11 +748,9 @@ async def sample_list_model_versions(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_model_versions, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_model_versions + ] # Certain fields should be provided within the metadata header; # add these here. @@ -864,8 +873,8 @@ async def sample_update_model(): A trained machine learning Model. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -873,7 +882,10 @@ async def sample_update_model(): "the individual field arguments should be set." ) - request = model_service.UpdateModelRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.UpdateModelRequest): + request = model_service.UpdateModelRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -884,11 +896,9 @@ async def sample_update_model(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_model, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_model + ] # Certain fields should be provided within the metadata header; # add these here. @@ -984,8 +994,8 @@ async def sample_update_explanation_dataset(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model]) if request is not None and has_flattened_params: raise ValueError( @@ -993,7 +1003,10 @@ async def sample_update_explanation_dataset(): "the individual field arguments should be set." ) - request = model_service.UpdateExplanationDatasetRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.UpdateExplanationDatasetRequest): + request = model_service.UpdateExplanationDatasetRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1002,11 +1015,9 @@ async def sample_update_explanation_dataset(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_explanation_dataset, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_explanation_dataset + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1120,8 +1131,8 @@ async def sample_delete_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1129,7 +1140,10 @@ async def sample_delete_model(): "the individual field arguments should be set." 
) - request = model_service.DeleteModelRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.DeleteModelRequest): + request = model_service.DeleteModelRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1138,11 +1152,9 @@ async def sample_delete_model(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_model, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_model + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1257,8 +1269,8 @@ async def sample_delete_model_version(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1266,7 +1278,10 @@ async def sample_delete_model_version(): "the individual field arguments should be set." ) - request = model_service.DeleteModelVersionRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.DeleteModelVersionRequest): + request = model_service.DeleteModelVersionRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -1275,11 +1290,9 @@ async def sample_delete_model_version(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_model_version, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_model_version + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1393,8 +1406,8 @@ async def sample_merge_version_aliases(): A trained machine learning Model. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, version_aliases]) if request is not None and has_flattened_params: raise ValueError( @@ -1402,7 +1415,10 @@ async def sample_merge_version_aliases(): "the individual field arguments should be set." ) - request = model_service.MergeVersionAliasesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.MergeVersionAliasesRequest): + request = model_service.MergeVersionAliasesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1413,11 +1429,9 @@ async def sample_merge_version_aliases(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.merge_version_aliases, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.merge_version_aliases + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1522,8 +1536,8 @@ async def sample_export_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, output_config]) if request is not None and has_flattened_params: raise ValueError( @@ -1531,7 +1545,10 @@ async def sample_export_model(): "the individual field arguments should be set." ) - request = model_service.ExportModelRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.ExportModelRequest): + request = model_service.ExportModelRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1542,11 +1559,9 @@ async def sample_export_model(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.export_model, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.export_model + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1661,8 +1676,8 @@ async def sample_copy_model(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, source_model]) if request is not None and has_flattened_params: raise ValueError( @@ -1670,7 +1685,10 @@ async def sample_copy_model(): "the individual field arguments should be set." ) - request = model_service.CopyModelRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.CopyModelRequest): + request = model_service.CopyModelRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1681,11 +1699,9 @@ async def sample_copy_model(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.copy_model, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.copy_model + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1788,8 +1804,8 @@ async def sample_import_model_evaluation(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_evaluation]) if request is not None and has_flattened_params: raise ValueError( @@ -1797,7 +1813,10 @@ async def sample_import_model_evaluation(): "the individual field arguments should be set." 
) - request = model_service.ImportModelEvaluationRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.ImportModelEvaluationRequest): + request = model_service.ImportModelEvaluationRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1808,11 +1827,9 @@ async def sample_import_model_evaluation(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.import_model_evaluation, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.import_model_evaluation + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1909,8 +1926,8 @@ async def sample_batch_import_model_evaluation_slices(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_evaluation_slices]) if request is not None and has_flattened_params: raise ValueError( @@ -1918,7 +1935,12 @@ async def sample_batch_import_model_evaluation_slices(): "the individual field arguments should be set." ) - request = model_service.BatchImportModelEvaluationSlicesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance( + request, model_service.BatchImportModelEvaluationSlicesRequest + ): + request = model_service.BatchImportModelEvaluationSlicesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1929,11 +1951,9 @@ async def sample_batch_import_model_evaluation_slices(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.batch_import_model_evaluation_slices, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_import_model_evaluation_slices + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2030,8 +2050,8 @@ async def sample_batch_import_evaluated_annotations(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, evaluated_annotations]) if request is not None and has_flattened_params: raise ValueError( @@ -2039,7 +2059,12 @@ async def sample_batch_import_evaluated_annotations(): "the individual field arguments should be set." ) - request = model_service.BatchImportEvaluatedAnnotationsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, model_service.BatchImportEvaluatedAnnotationsRequest + ): + request = model_service.BatchImportEvaluatedAnnotationsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -2050,11 +2075,9 @@ async def sample_batch_import_evaluated_annotations(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.batch_import_evaluated_annotations, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_import_evaluated_annotations + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2140,8 +2163,8 @@ async def sample_get_model_evaluation(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2149,7 +2172,10 @@ async def sample_get_model_evaluation(): "the individual field arguments should be set." ) - request = model_service.GetModelEvaluationRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.GetModelEvaluationRequest): + request = model_service.GetModelEvaluationRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2158,11 +2184,9 @@ async def sample_get_model_evaluation(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_model_evaluation, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_model_evaluation + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2252,8 +2276,8 @@ async def sample_list_model_evaluations(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2261,7 +2285,10 @@ async def sample_list_model_evaluations(): "the individual field arguments should be set." ) - request = model_service.ListModelEvaluationsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.ListModelEvaluationsRequest): + request = model_service.ListModelEvaluationsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2270,11 +2297,9 @@ async def sample_list_model_evaluations(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_model_evaluations, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_model_evaluations + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2371,8 +2396,8 @@ async def sample_get_model_evaluation_slice(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2380,7 +2405,10 @@ async def sample_get_model_evaluation_slice(): "the individual field arguments should be set." ) - request = model_service.GetModelEvaluationSliceRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.GetModelEvaluationSliceRequest): + request = model_service.GetModelEvaluationSliceRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2389,11 +2417,9 @@ async def sample_get_model_evaluation_slice(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_model_evaluation_slice, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_model_evaluation_slice + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2483,8 +2509,8 @@ async def sample_list_model_evaluation_slices(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2492,7 +2518,10 @@ async def sample_list_model_evaluation_slices(): "the individual field arguments should be set." ) - request = model_service.ListModelEvaluationSlicesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.ListModelEvaluationSlicesRequest): + request = model_service.ListModelEvaluationSlicesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2501,11 +2530,9 @@ async def sample_list_model_evaluation_slices(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_model_evaluation_slices, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_model_evaluation_slices + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/client.py b/google/cloud/aiplatform_v1beta1/services/model_service/client.py index 5e281f7715..55bd9724e5 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -642,7 +643,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, ModelServiceTransport]] = None, + transport: Optional[ + Union[str, ModelServiceTransport, Callable[..., ModelServiceTransport]] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -654,9 +657,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ModelServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,ModelServiceTransport,Callable[..., ModelServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ModelServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -765,8 +770,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[ModelServiceTransport], Callable[..., ModelServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., ModelServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -857,8 +869,8 @@ def sample_upload_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model]) if request is not None and has_flattened_params: raise ValueError( @@ -866,10 +878,8 @@ def sample_upload_model(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.UploadModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.UploadModelRequest): request = model_service.UploadModelRequest(request) # If we have keyword arguments corresponding to fields on the @@ -981,8 +991,8 @@ def sample_get_model(): A trained machine learning Model. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -990,10 +1000,8 @@ def sample_get_model(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.GetModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.GetModelRequest): request = model_service.GetModelRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1091,8 +1099,8 @@ def sample_list_models(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1100,10 +1108,8 @@ def sample_list_models(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.ListModelsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, model_service.ListModelsRequest): request = model_service.ListModelsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1209,8 +1215,8 @@ def sample_list_model_versions(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1218,10 +1224,8 @@ def sample_list_model_versions(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.ListModelVersionsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.ListModelVersionsRequest): request = model_service.ListModelVersionsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1354,8 +1358,8 @@ def sample_update_model(): A trained machine learning Model. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1363,10 +1367,8 @@ def sample_update_model(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.UpdateModelRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.UpdateModelRequest): request = model_service.UpdateModelRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1474,8 +1476,8 @@ def sample_update_explanation_dataset(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model]) if request is not None and has_flattened_params: raise ValueError( @@ -1483,10 +1485,8 @@ def sample_update_explanation_dataset(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.UpdateExplanationDatasetRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.UpdateExplanationDatasetRequest): request = model_service.UpdateExplanationDatasetRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1612,8 +1612,8 @@ def sample_delete_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1621,10 +1621,8 @@ def sample_delete_model(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.DeleteModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.DeleteModelRequest): request = model_service.DeleteModelRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1749,8 +1747,8 @@ def sample_delete_model_version(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1758,10 +1756,8 @@ def sample_delete_model_version(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.DeleteModelVersionRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.DeleteModelVersionRequest): request = model_service.DeleteModelVersionRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1885,8 +1881,8 @@ def sample_merge_version_aliases(): A trained machine learning Model. """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, version_aliases]) if request is not None and has_flattened_params: raise ValueError( @@ -1894,10 +1890,8 @@ def sample_merge_version_aliases(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.MergeVersionAliasesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.MergeVersionAliasesRequest): request = model_service.MergeVersionAliasesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2014,8 +2008,8 @@ def sample_export_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, output_config]) if request is not None and has_flattened_params: raise ValueError( @@ -2023,10 +2017,8 @@ def sample_export_model(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.ExportModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, model_service.ExportModelRequest): request = model_service.ExportModelRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2153,8 +2145,8 @@ def sample_copy_model(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, source_model]) if request is not None and has_flattened_params: raise ValueError( @@ -2162,10 +2154,8 @@ def sample_copy_model(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.CopyModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.CopyModelRequest): request = model_service.CopyModelRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2280,8 +2270,8 @@ def sample_import_model_evaluation(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_evaluation]) if request is not None and has_flattened_params: raise ValueError( @@ -2289,10 +2279,8 @@ def sample_import_model_evaluation(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.ImportModelEvaluationRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.ImportModelEvaluationRequest): request = model_service.ImportModelEvaluationRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2401,8 +2389,8 @@ def sample_batch_import_model_evaluation_slices(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model_evaluation_slices]) if request is not None and has_flattened_params: raise ValueError( @@ -2410,10 +2398,8 @@ def sample_batch_import_model_evaluation_slices(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.BatchImportModelEvaluationSlicesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, model_service.BatchImportModelEvaluationSlicesRequest ): @@ -2526,8 +2512,8 @@ def sample_batch_import_evaluated_annotations(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, evaluated_annotations]) if request is not None and has_flattened_params: raise ValueError( @@ -2535,10 +2521,8 @@ def sample_batch_import_evaluated_annotations(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.BatchImportEvaluatedAnnotationsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, model_service.BatchImportEvaluatedAnnotationsRequest ): @@ -2640,8 +2624,8 @@ def sample_get_model_evaluation(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2649,10 +2633,8 @@ def sample_get_model_evaluation(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.GetModelEvaluationRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.GetModelEvaluationRequest): request = model_service.GetModelEvaluationRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2752,8 +2734,8 @@ def sample_list_model_evaluations(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2761,10 +2743,8 @@ def sample_list_model_evaluations(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.ListModelEvaluationsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, model_service.ListModelEvaluationsRequest): request = model_service.ListModelEvaluationsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2871,8 +2851,8 @@ def sample_get_model_evaluation_slice(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2880,10 +2860,8 @@ def sample_get_model_evaluation_slice(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.GetModelEvaluationSliceRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, model_service.GetModelEvaluationSliceRequest): request = model_service.GetModelEvaluationSliceRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2985,8 +2963,8 @@ def sample_list_model_evaluation_slices(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2994,10 +2972,8 @@ def sample_list_model_evaluation_slices(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.ListModelEvaluationSlicesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, model_service.ListModelEvaluationSlicesRequest): request = model_service.ListModelEvaluationSlicesRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py index 662d1f0a71..b2d85c7dd1 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py @@ -62,7 +62,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -82,14 +82,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. 
+ channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -99,11 +102,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -130,7 +133,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -171,7 +174,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py index af85d4ae89..f0f9011b60 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -77,7 +79,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -107,7 +108,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -127,15 +128,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -145,11 +149,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -176,7 +180,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -216,7 +220,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -800,6 +806,101 @@ def list_model_evaluation_slices( ) return self._stubs["list_model_evaluation_slices"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.upload_model: gapic_v1.method_async.wrap_method( + self.upload_model, + default_timeout=5.0, + client_info=client_info, + ), + self.get_model: gapic_v1.method_async.wrap_method( + self.get_model, + default_timeout=5.0, + client_info=client_info, + ), + self.list_models: gapic_v1.method_async.wrap_method( + self.list_models, + default_timeout=5.0, + client_info=client_info, + ), + self.list_model_versions: gapic_v1.method_async.wrap_method( + self.list_model_versions, + default_timeout=None, + client_info=client_info, + ), + self.update_model: gapic_v1.method_async.wrap_method( + self.update_model, + default_timeout=5.0, + client_info=client_info, + ), + self.update_explanation_dataset: gapic_v1.method_async.wrap_method( + self.update_explanation_dataset, + default_timeout=None, + client_info=client_info, + ), + self.delete_model: gapic_v1.method_async.wrap_method( + self.delete_model, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_model_version: gapic_v1.method_async.wrap_method( + self.delete_model_version, + default_timeout=None, + client_info=client_info, + ), + self.merge_version_aliases: gapic_v1.method_async.wrap_method( + self.merge_version_aliases, + default_timeout=None, + client_info=client_info, + ), + self.export_model: gapic_v1.method_async.wrap_method( + self.export_model, + default_timeout=5.0, + 
client_info=client_info, + ), + self.copy_model: gapic_v1.method_async.wrap_method( + self.copy_model, + default_timeout=5.0, + client_info=client_info, + ), + self.import_model_evaluation: gapic_v1.method_async.wrap_method( + self.import_model_evaluation, + default_timeout=None, + client_info=client_info, + ), + self.batch_import_model_evaluation_slices: gapic_v1.method_async.wrap_method( + self.batch_import_model_evaluation_slices, + default_timeout=None, + client_info=client_info, + ), + self.batch_import_evaluated_annotations: gapic_v1.method_async.wrap_method( + self.batch_import_evaluated_annotations, + default_timeout=None, + client_info=client_info, + ), + self.get_model_evaluation: gapic_v1.method_async.wrap_method( + self.get_model_evaluation, + default_timeout=5.0, + client_info=client_info, + ), + self.list_model_evaluations: gapic_v1.method_async.wrap_method( + self.list_model_evaluations, + default_timeout=5.0, + client_info=client_info, + ), + self.get_model_evaluation_slice: gapic_v1.method_async.wrap_method( + self.get_model_evaluation_slice, + default_timeout=5.0, + client_info=client_info, + ), + self.list_model_evaluation_slices: gapic_v1.method_async.wrap_method( + self.list_model_evaluation_slices, + default_timeout=5.0, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/rest.py index 0848c65b34..8a17cd73a6 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/rest.py @@ -1190,10 +1190,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1564,10 +1560,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -1926,10 +1918,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2304,10 +2292,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -2682,10 +2666,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -5422,10 +5402,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - 
"method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -5853,10 +5829,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -6275,10 +6247,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -6714,10 +6682,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -7153,10 +7117,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/notebook_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/notebook_service/async_client.py index 8bb9d123b5..5f0cdc9b32 100644 --- a/google/cloud/aiplatform_v1beta1/services/notebook_service/async_client.py +++ 
b/google/cloud/aiplatform_v1beta1/services/notebook_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -45,9 +46,11 @@ from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.notebook_service import pagers +from google.cloud.aiplatform_v1beta1.types import job_state from google.cloud.aiplatform_v1beta1.types import machine_resources from google.cloud.aiplatform_v1beta1.types import network_spec from google.cloud.aiplatform_v1beta1.types import notebook_euc_config +from google.cloud.aiplatform_v1beta1.types import notebook_execution_job from google.cloud.aiplatform_v1beta1.types import notebook_idle_shutdown_config from google.cloud.aiplatform_v1beta1.types import notebook_runtime from google.cloud.aiplatform_v1beta1.types import ( @@ -60,8 +63,10 @@ from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import policy_pb2 # type: ignore from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import duration_pb2 # type: ignore from google.protobuf import empty_pb2 # type: ignore from google.protobuf import timestamp_pb2 # type: ignore +from google.rpc import status_pb2 # type: ignore from .transports.base import NotebookServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc_asyncio import NotebookServiceGrpcAsyncIOTransport from .client import NotebookServiceClient @@ -83,6 +88,12 @@ class NotebookServiceAsyncClient: network_path = staticmethod(NotebookServiceClient.network_path) parse_network_path = staticmethod(NotebookServiceClient.parse_network_path) + notebook_execution_job_path = staticmethod( + NotebookServiceClient.notebook_execution_job_path + ) + parse_notebook_execution_job_path = staticmethod( + NotebookServiceClient.parse_notebook_execution_job_path + ) notebook_runtime_path = 
staticmethod(NotebookServiceClient.notebook_runtime_path) parse_notebook_runtime_path = staticmethod( NotebookServiceClient.parse_notebook_runtime_path @@ -93,6 +104,8 @@ class NotebookServiceAsyncClient: parse_notebook_runtime_template_path = staticmethod( NotebookServiceClient.parse_notebook_runtime_template_path ) + schedule_path = staticmethod(NotebookServiceClient.schedule_path) + parse_schedule_path = staticmethod(NotebookServiceClient.parse_schedule_path) subnetwork_path = staticmethod(NotebookServiceClient.subnetwork_path) parse_subnetwork_path = staticmethod(NotebookServiceClient.parse_subnetwork_path) common_billing_account_path = staticmethod( @@ -225,7 +238,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, NotebookServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, NotebookServiceTransport, Callable[..., NotebookServiceTransport] + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -237,9 +254,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.NotebookServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,NotebookServiceTransport,Callable[..., NotebookServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the NotebookServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -380,8 +399,8 @@ async def sample_create_notebook_runtime_template(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, notebook_runtime_template, notebook_runtime_template_id] ) @@ -391,7 +410,12 @@ async def sample_create_notebook_runtime_template(): "the individual field arguments should be set." ) - request = notebook_service.CreateNotebookRuntimeTemplateRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, notebook_service.CreateNotebookRuntimeTemplateRequest + ): + request = notebook_service.CreateNotebookRuntimeTemplateRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -404,11 +428,9 @@ async def sample_create_notebook_runtime_template(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_notebook_runtime_template, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_notebook_runtime_template + ] # Certain fields should be provided within the metadata header; # add these here. @@ -505,8 +527,8 @@ async def sample_get_notebook_runtime_template(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -514,7 +536,10 @@ async def sample_get_notebook_runtime_template(): "the individual field arguments should be set." ) - request = notebook_service.GetNotebookRuntimeTemplateRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, notebook_service.GetNotebookRuntimeTemplateRequest): + request = notebook_service.GetNotebookRuntimeTemplateRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -523,11 +548,9 @@ async def sample_get_notebook_runtime_template(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_notebook_runtime_template, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_notebook_runtime_template + ] # Certain fields should be provided within the metadata header; # add these here. @@ -617,8 +640,8 @@ async def sample_list_notebook_runtime_templates(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -626,7 +649,12 @@ async def sample_list_notebook_runtime_templates(): "the individual field arguments should be set." ) - request = notebook_service.ListNotebookRuntimeTemplatesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance( + request, notebook_service.ListNotebookRuntimeTemplatesRequest + ): + request = notebook_service.ListNotebookRuntimeTemplatesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -635,11 +663,9 @@ async def sample_list_notebook_runtime_templates(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_notebook_runtime_templates, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_notebook_runtime_templates + ] # Certain fields should be provided within the metadata header; # add these here. @@ -748,8 +774,8 @@ async def sample_delete_notebook_runtime_template(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -757,7 +783,12 @@ async def sample_delete_notebook_runtime_template(): "the individual field arguments should be set." ) - request = notebook_service.DeleteNotebookRuntimeTemplateRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, notebook_service.DeleteNotebookRuntimeTemplateRequest + ): + request = notebook_service.DeleteNotebookRuntimeTemplateRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -766,11 +797,9 @@ async def sample_delete_notebook_runtime_template(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_notebook_runtime_template, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_notebook_runtime_template + ] # Certain fields should be provided within the metadata header; # add these here. @@ -907,8 +936,8 @@ async def sample_assign_notebook_runtime(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, notebook_runtime_template, notebook_runtime, notebook_runtime_id] ) @@ -918,7 +947,10 @@ async def sample_assign_notebook_runtime(): "the individual field arguments should be set." ) - request = notebook_service.AssignNotebookRuntimeRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, notebook_service.AssignNotebookRuntimeRequest): + request = notebook_service.AssignNotebookRuntimeRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -933,11 +965,9 @@ async def sample_assign_notebook_runtime(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.assign_notebook_runtime, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.assign_notebook_runtime + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1037,8 +1067,8 @@ async def sample_get_notebook_runtime(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1046,7 +1076,10 @@ async def sample_get_notebook_runtime(): "the individual field arguments should be set." ) - request = notebook_service.GetNotebookRuntimeRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, notebook_service.GetNotebookRuntimeRequest): + request = notebook_service.GetNotebookRuntimeRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1055,11 +1088,9 @@ async def sample_get_notebook_runtime(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_notebook_runtime, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_notebook_runtime + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1149,8 +1180,8 @@ async def sample_list_notebook_runtimes(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1158,7 +1189,10 @@ async def sample_list_notebook_runtimes(): "the individual field arguments should be set." ) - request = notebook_service.ListNotebookRuntimesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, notebook_service.ListNotebookRuntimesRequest): + request = notebook_service.ListNotebookRuntimesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1167,11 +1201,9 @@ async def sample_list_notebook_runtimes(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_notebook_runtimes, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_notebook_runtimes + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1284,8 +1316,8 @@ async def sample_delete_notebook_runtime(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1293,7 +1325,10 @@ async def sample_delete_notebook_runtime(): "the individual field arguments should be set." 
) - request = notebook_service.DeleteNotebookRuntimeRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, notebook_service.DeleteNotebookRuntimeRequest): + request = notebook_service.DeleteNotebookRuntimeRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1302,11 +1337,9 @@ async def sample_delete_notebook_runtime(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_notebook_runtime, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_notebook_runtime + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1410,8 +1443,8 @@ async def sample_upgrade_notebook_runtime(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1419,7 +1452,10 @@ async def sample_upgrade_notebook_runtime(): "the individual field arguments should be set." ) - request = notebook_service.UpgradeNotebookRuntimeRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, notebook_service.UpgradeNotebookRuntimeRequest): + request = notebook_service.UpgradeNotebookRuntimeRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -1428,11 +1464,9 @@ async def sample_upgrade_notebook_runtime(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.upgrade_notebook_runtime, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.upgrade_notebook_runtime + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1536,8 +1570,8 @@ async def sample_start_notebook_runtime(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1545,7 +1579,10 @@ async def sample_start_notebook_runtime(): "the individual field arguments should be set." ) - request = notebook_service.StartNotebookRuntimeRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, notebook_service.StartNotebookRuntimeRequest): + request = notebook_service.StartNotebookRuntimeRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1554,11 +1591,9 @@ async def sample_start_notebook_runtime(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.start_notebook_runtime, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.start_notebook_runtime + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1588,6 +1623,367 @@ async def sample_start_notebook_runtime(): # Done; return the response. return response + async def get_notebook_execution_job( + self, + request: Optional[ + Union[notebook_service.GetNotebookExecutionJobRequest, dict] + ] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> notebook_execution_job.NotebookExecutionJob: + r"""Gets a NotebookExecutionJob. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1beta1 + + async def sample_get_notebook_execution_job(): + # Create a client + client = aiplatform_v1beta1.NotebookServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.GetNotebookExecutionJobRequest( + name="name_value", + ) + + # Make the request + response = await client.get_notebook_execution_job(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.cloud.aiplatform_v1beta1.types.GetNotebookExecutionJobRequest, dict]]): + The request object. Request message for + [NotebookService.GetNotebookExecutionJob] + name (:class:`str`): + Required. 
The name of the + NotebookExecutionJob resource. + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.NotebookExecutionJob: + NotebookExecutionJob represents an + instance of a notebook execution. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, notebook_service.GetNotebookExecutionJobRequest): + request = notebook_service.GetNotebookExecutionJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_notebook_execution_job + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. 
+ response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_notebook_execution_jobs( + self, + request: Optional[ + Union[notebook_service.ListNotebookExecutionJobsRequest, dict] + ] = None, + *, + parent: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListNotebookExecutionJobsAsyncPager: + r"""Lists NotebookExecutionJobs in a Location. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1beta1 + + async def sample_list_notebook_execution_jobs(): + # Create a client + client = aiplatform_v1beta1.NotebookServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.ListNotebookExecutionJobsRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_notebook_execution_jobs(request=request) + + # Handle the response + async for response in page_result: + print(response) + + Args: + request (Optional[Union[google.cloud.aiplatform_v1beta1.types.ListNotebookExecutionJobsRequest, dict]]): + The request object. Request message for + [NotebookService.ListNotebookExecutionJobs] + parent (:class:`str`): + Required. The resource name of the Location from which + to list the NotebookExecutionJobs. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. 
+ retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.notebook_service.pagers.ListNotebookExecutionJobsAsyncPager: + Response message for + [NotebookService.CreateNotebookExecutionJob] + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, notebook_service.ListNotebookExecutionJobsRequest): + request = notebook_service.ListNotebookExecutionJobsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_notebook_execution_jobs + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. 
+ response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListNotebookExecutionJobsAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_notebook_execution_job( + self, + request: Optional[ + Union[notebook_service.DeleteNotebookExecutionJobRequest, dict] + ] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a NotebookExecutionJob. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1beta1 + + async def sample_delete_notebook_execution_job(): + # Create a client + client = aiplatform_v1beta1.NotebookServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.DeleteNotebookExecutionJobRequest( + name="name_value", + ) + + # Make the request + operation = client.delete_notebook_execution_job(request=request) + + print("Waiting for operation to complete...") + + response = (await operation).result() + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.cloud.aiplatform_v1beta1.types.DeleteNotebookExecutionJobRequest, dict]]): + The request object. 
Request message for + [NotebookService.DeleteNotebookExecutionJob] + name (:class:`str`): + Required. The name of the + NotebookExecutionJob resource to be + deleted. + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, notebook_service.DeleteNotebookExecutionJobRequest): + request = notebook_service.DeleteNotebookExecutionJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
+ rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_notebook_execution_job + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty_pb2.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + async def list_operations( self, request: Optional[operations_pb2.ListOperationsRequest] = None, diff --git a/google/cloud/aiplatform_v1beta1/services/notebook_service/client.py b/google/cloud/aiplatform_v1beta1/services/notebook_service/client.py index a0ff87a2f0..4785b2aead 100644 --- a/google/cloud/aiplatform_v1beta1/services/notebook_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/notebook_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -50,9 +51,11 @@ from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.notebook_service import pagers +from google.cloud.aiplatform_v1beta1.types import job_state from google.cloud.aiplatform_v1beta1.types import machine_resources from google.cloud.aiplatform_v1beta1.types import network_spec from google.cloud.aiplatform_v1beta1.types import notebook_euc_config +from google.cloud.aiplatform_v1beta1.types import notebook_execution_job from google.cloud.aiplatform_v1beta1.types import notebook_idle_shutdown_config from 
google.cloud.aiplatform_v1beta1.types import notebook_runtime from google.cloud.aiplatform_v1beta1.types import ( @@ -65,8 +68,10 @@ from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import policy_pb2 # type: ignore from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import duration_pb2 # type: ignore from google.protobuf import empty_pb2 # type: ignore from google.protobuf import timestamp_pb2 # type: ignore +from google.rpc import status_pb2 # type: ignore from .transports.base import NotebookServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc import NotebookServiceGrpcTransport from .transports.grpc_asyncio import NotebookServiceGrpcAsyncIOTransport @@ -220,6 +225,28 @@ def parse_network_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def notebook_execution_job_path( + project: str, + location: str, + notebook_execution_job: str, + ) -> str: + """Returns a fully-qualified notebook_execution_job string.""" + return "projects/{project}/locations/{location}/notebookExecutionJobs/{notebook_execution_job}".format( + project=project, + location=location, + notebook_execution_job=notebook_execution_job, + ) + + @staticmethod + def parse_notebook_execution_job_path(path: str) -> Dict[str, str]: + """Parses a notebook_execution_job path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/notebookExecutionJobs/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def notebook_runtime_path( project: str, @@ -264,6 +291,28 @@ def parse_notebook_runtime_template_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def schedule_path( + project: str, + location: str, + schedule: str, + ) -> str: + """Returns a fully-qualified schedule string.""" + return "projects/{project}/locations/{location}/schedules/{schedule}".format( + project=project, + location=location, + 
schedule=schedule, + ) + + @staticmethod + def parse_schedule_path(path: str) -> Dict[str, str]: + """Parses a schedule path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/schedules/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def subnetwork_path( project: str, @@ -611,7 +660,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, NotebookServiceTransport]] = None, + transport: Optional[ + Union[ + str, NotebookServiceTransport, Callable[..., NotebookServiceTransport] + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -623,9 +676,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, NotebookServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,NotebookServiceTransport,Callable[..., NotebookServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the NotebookServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -737,8 +792,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[NotebookServiceTransport], Callable[..., NotebookServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., NotebookServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -844,8 +906,8 @@ def sample_create_notebook_runtime_template(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, notebook_runtime_template, notebook_runtime_template_id] ) @@ -855,10 +917,8 @@ def sample_create_notebook_runtime_template(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a notebook_service.CreateNotebookRuntimeTemplateRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, notebook_service.CreateNotebookRuntimeTemplateRequest ): @@ -973,8 +1033,8 @@ def sample_get_notebook_runtime_template(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -982,10 +1042,8 @@ def sample_get_notebook_runtime_template(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a notebook_service.GetNotebookRuntimeTemplateRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, notebook_service.GetNotebookRuntimeTemplateRequest): request = notebook_service.GetNotebookRuntimeTemplateRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1087,8 +1145,8 @@ def sample_list_notebook_runtime_templates(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1096,10 +1154,8 @@ def sample_list_notebook_runtime_templates(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a notebook_service.ListNotebookRuntimeTemplatesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, notebook_service.ListNotebookRuntimeTemplatesRequest ): @@ -1222,8 +1278,8 @@ def sample_delete_notebook_runtime_template(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1231,10 +1287,8 @@ def sample_delete_notebook_runtime_template(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a notebook_service.DeleteNotebookRuntimeTemplateRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, notebook_service.DeleteNotebookRuntimeTemplateRequest ): @@ -1385,8 +1439,8 @@ def sample_assign_notebook_runtime(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, notebook_runtime_template, notebook_runtime, notebook_runtime_id] ) @@ -1396,10 +1450,8 @@ def sample_assign_notebook_runtime(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a notebook_service.AssignNotebookRuntimeRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
+ # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, notebook_service.AssignNotebookRuntimeRequest): request = notebook_service.AssignNotebookRuntimeRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1515,8 +1567,8 @@ def sample_get_notebook_runtime(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1524,10 +1576,8 @@ def sample_get_notebook_runtime(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a notebook_service.GetNotebookRuntimeRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, notebook_service.GetNotebookRuntimeRequest): request = notebook_service.GetNotebookRuntimeRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1627,8 +1677,8 @@ def sample_list_notebook_runtimes(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1636,10 +1686,8 @@ def sample_list_notebook_runtimes(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a notebook_service.ListNotebookRuntimesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, notebook_service.ListNotebookRuntimesRequest): request = notebook_service.ListNotebookRuntimesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1762,8 +1810,8 @@ def sample_delete_notebook_runtime(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1771,10 +1819,8 @@ def sample_delete_notebook_runtime(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a notebook_service.DeleteNotebookRuntimeRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, notebook_service.DeleteNotebookRuntimeRequest): request = notebook_service.DeleteNotebookRuntimeRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1888,8 +1934,8 @@ def sample_upgrade_notebook_runtime(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1897,10 +1943,8 @@ def sample_upgrade_notebook_runtime(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a notebook_service.UpgradeNotebookRuntimeRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, notebook_service.UpgradeNotebookRuntimeRequest): request = notebook_service.UpgradeNotebookRuntimeRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2014,8 +2058,8 @@ def sample_start_notebook_runtime(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2023,10 +2067,8 @@ def sample_start_notebook_runtime(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a notebook_service.StartNotebookRuntimeRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, notebook_service.StartNotebookRuntimeRequest): request = notebook_service.StartNotebookRuntimeRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2066,6 +2108,364 @@ def sample_start_notebook_runtime(): # Done; return the response. return response + def get_notebook_execution_job( + self, + request: Optional[ + Union[notebook_service.GetNotebookExecutionJobRequest, dict] + ] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> notebook_execution_job.NotebookExecutionJob: + r"""Gets a NotebookExecutionJob. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1beta1 + + def sample_get_notebook_execution_job(): + # Create a client + client = aiplatform_v1beta1.NotebookServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.GetNotebookExecutionJobRequest( + name="name_value", + ) + + # Make the request + response = client.get_notebook_execution_job(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.cloud.aiplatform_v1beta1.types.GetNotebookExecutionJobRequest, dict]): + The request object. Request message for + [NotebookService.GetNotebookExecutionJob] + name (str): + Required. The name of the + NotebookExecutionJob resource. + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.NotebookExecutionJob: + NotebookExecutionJob represents an + instance of a notebook execution. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, notebook_service.GetNotebookExecutionJobRequest): + request = notebook_service.GetNotebookExecutionJobRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.get_notebook_execution_job + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_notebook_execution_jobs( + self, + request: Optional[ + Union[notebook_service.ListNotebookExecutionJobsRequest, dict] + ] = None, + *, + parent: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListNotebookExecutionJobsPager: + r"""Lists NotebookExecutionJobs in a Location. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1beta1 + + def sample_list_notebook_execution_jobs(): + # Create a client + client = aiplatform_v1beta1.NotebookServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.ListNotebookExecutionJobsRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_notebook_execution_jobs(request=request) + + # Handle the response + for response in page_result: + print(response) + + Args: + request (Union[google.cloud.aiplatform_v1beta1.types.ListNotebookExecutionJobsRequest, dict]): + The request object. Request message for + [NotebookService.ListNotebookExecutionJobs] + parent (str): + Required. The resource name of the Location from which + to list the NotebookExecutionJobs. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.notebook_service.pagers.ListNotebookExecutionJobsPager: + Response message for + [NotebookService.CreateNotebookExecutionJob] + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
+ has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, notebook_service.ListNotebookExecutionJobsRequest): + request = notebook_service.ListNotebookExecutionJobsRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.list_notebook_execution_jobs + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListNotebookExecutionJobsPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_notebook_execution_job( + self, + request: Optional[ + Union[notebook_service.DeleteNotebookExecutionJobRequest, dict] + ] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Deletes a NotebookExecutionJob. + + .. 
code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1beta1 + + def sample_delete_notebook_execution_job(): + # Create a client + client = aiplatform_v1beta1.NotebookServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.DeleteNotebookExecutionJobRequest( + name="name_value", + ) + + # Make the request + operation = client.delete_notebook_execution_job(request=request) + + print("Waiting for operation to complete...") + + response = operation.result() + + # Handle the response + print(response) + + Args: + request (Union[google.cloud.aiplatform_v1beta1.types.DeleteNotebookExecutionJobRequest, dict]): + The request object. Request message for + [NotebookService.DeleteNotebookExecutionJob] + name (str): + Required. The name of the + NotebookExecutionJob resource to be + deleted. + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. 
For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, notebook_service.DeleteNotebookExecutionJobRequest): + request = notebook_service.DeleteNotebookExecutionJobRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.delete_notebook_execution_job + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + empty_pb2.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. 
+ return response + def __enter__(self) -> "NotebookServiceClient": return self diff --git a/google/cloud/aiplatform_v1beta1/services/notebook_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/notebook_service/pagers.py index f6e467a90f..1f9a2dc346 100644 --- a/google/cloud/aiplatform_v1beta1/services/notebook_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/notebook_service/pagers.py @@ -24,6 +24,7 @@ Iterator, ) +from google.cloud.aiplatform_v1beta1.types import notebook_execution_job from google.cloud.aiplatform_v1beta1.types import notebook_runtime from google.cloud.aiplatform_v1beta1.types import notebook_service @@ -288,3 +289,135 @@ async def async_generator(): def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListNotebookExecutionJobsPager: + """A pager for iterating through ``list_notebook_execution_jobs`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListNotebookExecutionJobsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``notebook_execution_jobs`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListNotebookExecutionJobs`` requests and continue to iterate + through the ``notebook_execution_jobs`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListNotebookExecutionJobsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., notebook_service.ListNotebookExecutionJobsResponse], + request: notebook_service.ListNotebookExecutionJobsRequest, + response: notebook_service.ListNotebookExecutionJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. 
+ + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListNotebookExecutionJobsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListNotebookExecutionJobsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = notebook_service.ListNotebookExecutionJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterator[notebook_service.ListNotebookExecutionJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterator[notebook_execution_job.NotebookExecutionJob]: + for page in self.pages: + yield from page.notebook_execution_jobs + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListNotebookExecutionJobsAsyncPager: + """A pager for iterating through ``list_notebook_execution_jobs`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListNotebookExecutionJobsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``notebook_execution_jobs`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListNotebookExecutionJobs`` requests and continue to iterate + through the ``notebook_execution_jobs`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListNotebookExecutionJobsResponse` + attributes are available on the pager. 
If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[ + ..., Awaitable[notebook_service.ListNotebookExecutionJobsResponse] + ], + request: notebook_service.ListNotebookExecutionJobsRequest, + response: notebook_service.ListNotebookExecutionJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiates the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListNotebookExecutionJobsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListNotebookExecutionJobsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = notebook_service.ListNotebookExecutionJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages( + self, + ) -> AsyncIterator[notebook_service.ListNotebookExecutionJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterator[notebook_execution_job.NotebookExecutionJob]: + async def async_generator(): + async for page in self.pages: + for response in page.notebook_execution_jobs: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/notebook_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/notebook_service/transports/base.py index 
6772732ffa..be03009271 100644 --- a/google/cloud/aiplatform_v1beta1/services/notebook_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/notebook_service/transports/base.py @@ -27,6 +27,7 @@ from google.auth import credentials as ga_credentials # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.aiplatform_v1beta1.types import notebook_execution_job from google.cloud.aiplatform_v1beta1.types import notebook_runtime from google.cloud.aiplatform_v1beta1.types import notebook_service from google.cloud.location import locations_pb2 # type: ignore @@ -183,6 +184,21 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + self.get_notebook_execution_job: gapic_v1.method.wrap_method( + self.get_notebook_execution_job, + default_timeout=None, + client_info=client_info, + ), + self.list_notebook_execution_jobs: gapic_v1.method.wrap_method( + self.list_notebook_execution_jobs, + default_timeout=None, + client_info=client_info, + ), + self.delete_notebook_execution_job: gapic_v1.method.wrap_method( + self.delete_notebook_execution_job, + default_timeout=None, + client_info=client_info, + ), } def close(self): @@ -301,6 +317,39 @@ def start_notebook_runtime( ]: raise NotImplementedError() + @property + def get_notebook_execution_job( + self, + ) -> Callable[ + [notebook_service.GetNotebookExecutionJobRequest], + Union[ + notebook_execution_job.NotebookExecutionJob, + Awaitable[notebook_execution_job.NotebookExecutionJob], + ], + ]: + raise NotImplementedError() + + @property + def list_notebook_execution_jobs( + self, + ) -> Callable[ + [notebook_service.ListNotebookExecutionJobsRequest], + Union[ + notebook_service.ListNotebookExecutionJobsResponse, + Awaitable[notebook_service.ListNotebookExecutionJobsResponse], + ], + ]: + raise NotImplementedError() + + @property + def delete_notebook_execution_job( + self, + ) -> Callable[ + 
[notebook_service.DeleteNotebookExecutionJobRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + @property def list_operations( self, diff --git a/google/cloud/aiplatform_v1beta1/services/notebook_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/notebook_service/transports/grpc.py index 097e8b9565..be84d42e36 100644 --- a/google/cloud/aiplatform_v1beta1/services/notebook_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/notebook_service/transports/grpc.py @@ -25,6 +25,7 @@ import grpc # type: ignore +from google.cloud.aiplatform_v1beta1.types import notebook_execution_job from google.cloud.aiplatform_v1beta1.types import notebook_runtime from google.cloud.aiplatform_v1beta1.types import notebook_service from google.cloud.location import locations_pb2 # type: ignore @@ -57,7 +58,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -77,14 +78,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. 
This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -94,11 +98,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -125,7 +129,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -166,7 +170,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -549,6 +555,94 @@ def start_notebook_runtime( ) return self._stubs["start_notebook_runtime"] + @property + def get_notebook_execution_job( + self, + ) -> Callable[ + [notebook_service.GetNotebookExecutionJobRequest], + notebook_execution_job.NotebookExecutionJob, + ]: + r"""Return a callable for the get notebook execution job method over gRPC. + + Gets a NotebookExecutionJob. + + Returns: + Callable[[~.GetNotebookExecutionJobRequest], + ~.NotebookExecutionJob]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_notebook_execution_job" not in self._stubs: + self._stubs["get_notebook_execution_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.NotebookService/GetNotebookExecutionJob", + request_serializer=notebook_service.GetNotebookExecutionJobRequest.serialize, + response_deserializer=notebook_execution_job.NotebookExecutionJob.deserialize, + ) + return self._stubs["get_notebook_execution_job"] + + @property + def list_notebook_execution_jobs( + self, + ) -> Callable[ + [notebook_service.ListNotebookExecutionJobsRequest], + notebook_service.ListNotebookExecutionJobsResponse, + ]: + r"""Return a callable for the list notebook execution jobs method over gRPC. + + Lists NotebookExecutionJobs in a Location. 
+ + Returns: + Callable[[~.ListNotebookExecutionJobsRequest], + ~.ListNotebookExecutionJobsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_notebook_execution_jobs" not in self._stubs: + self._stubs["list_notebook_execution_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.NotebookService/ListNotebookExecutionJobs", + request_serializer=notebook_service.ListNotebookExecutionJobsRequest.serialize, + response_deserializer=notebook_service.ListNotebookExecutionJobsResponse.deserialize, + ) + return self._stubs["list_notebook_execution_jobs"] + + @property + def delete_notebook_execution_job( + self, + ) -> Callable[ + [notebook_service.DeleteNotebookExecutionJobRequest], operations_pb2.Operation + ]: + r"""Return a callable for the delete notebook execution job method over gRPC. + + Deletes a NotebookExecutionJob. + + Returns: + Callable[[~.DeleteNotebookExecutionJobRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "delete_notebook_execution_job" not in self._stubs: + self._stubs[ + "delete_notebook_execution_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.NotebookService/DeleteNotebookExecutionJob", + request_serializer=notebook_service.DeleteNotebookExecutionJobRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["delete_notebook_execution_job"] + def close(self): self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/notebook_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/notebook_service/transports/grpc_asyncio.py index 1f6ff513a7..688e29eac7 100644 --- a/google/cloud/aiplatform_v1beta1/services/notebook_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/notebook_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -25,6 +27,7 @@ import grpc # type: ignore from grpc.experimental import aio # type: ignore +from google.cloud.aiplatform_v1beta1.types import notebook_execution_job from google.cloud.aiplatform_v1beta1.types import notebook_runtime from google.cloud.aiplatform_v1beta1.types import notebook_service from google.cloud.location import locations_pb2 # type: ignore @@ -72,7 +75,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. 
These are only used when credentials are not specified and are passed to :func:`google.auth.default`. @@ -102,7 +104,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -122,15 +124,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. 
If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -140,11 +145,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -171,7 +176,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. @@ -211,7 +216,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -559,6 +566,165 @@ def start_notebook_runtime( ) return self._stubs["start_notebook_runtime"] + @property + def get_notebook_execution_job( + self, + ) -> Callable[ + [notebook_service.GetNotebookExecutionJobRequest], + Awaitable[notebook_execution_job.NotebookExecutionJob], + ]: + r"""Return a callable for the get notebook execution job method over gRPC. 
+ + Gets a NotebookExecutionJob. + + Returns: + Callable[[~.GetNotebookExecutionJobRequest], + Awaitable[~.NotebookExecutionJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_notebook_execution_job" not in self._stubs: + self._stubs["get_notebook_execution_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.NotebookService/GetNotebookExecutionJob", + request_serializer=notebook_service.GetNotebookExecutionJobRequest.serialize, + response_deserializer=notebook_execution_job.NotebookExecutionJob.deserialize, + ) + return self._stubs["get_notebook_execution_job"] + + @property + def list_notebook_execution_jobs( + self, + ) -> Callable[ + [notebook_service.ListNotebookExecutionJobsRequest], + Awaitable[notebook_service.ListNotebookExecutionJobsResponse], + ]: + r"""Return a callable for the list notebook execution jobs method over gRPC. + + Lists NotebookExecutionJobs in a Location. + + Returns: + Callable[[~.ListNotebookExecutionJobsRequest], + Awaitable[~.ListNotebookExecutionJobsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "list_notebook_execution_jobs" not in self._stubs: + self._stubs["list_notebook_execution_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.NotebookService/ListNotebookExecutionJobs", + request_serializer=notebook_service.ListNotebookExecutionJobsRequest.serialize, + response_deserializer=notebook_service.ListNotebookExecutionJobsResponse.deserialize, + ) + return self._stubs["list_notebook_execution_jobs"] + + @property + def delete_notebook_execution_job( + self, + ) -> Callable[ + [notebook_service.DeleteNotebookExecutionJobRequest], + Awaitable[operations_pb2.Operation], + ]: + r"""Return a callable for the delete notebook execution job method over gRPC. + + Deletes a NotebookExecutionJob. + + Returns: + Callable[[~.DeleteNotebookExecutionJobRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "delete_notebook_execution_job" not in self._stubs: + self._stubs[ + "delete_notebook_execution_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.NotebookService/DeleteNotebookExecutionJob", + request_serializer=notebook_service.DeleteNotebookExecutionJobRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["delete_notebook_execution_job"] + + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_notebook_runtime_template: gapic_v1.method_async.wrap_method( + self.create_notebook_runtime_template, + default_timeout=None, + client_info=client_info, + ), + self.get_notebook_runtime_template: gapic_v1.method_async.wrap_method( + self.get_notebook_runtime_template, + default_timeout=None, + client_info=client_info, + ), + self.list_notebook_runtime_templates: gapic_v1.method_async.wrap_method( + self.list_notebook_runtime_templates, + default_timeout=None, + client_info=client_info, + ), + self.delete_notebook_runtime_template: gapic_v1.method_async.wrap_method( + self.delete_notebook_runtime_template, + default_timeout=None, + client_info=client_info, + ), + self.assign_notebook_runtime: gapic_v1.method_async.wrap_method( + self.assign_notebook_runtime, + default_timeout=None, + client_info=client_info, + ), + self.get_notebook_runtime: gapic_v1.method_async.wrap_method( + self.get_notebook_runtime, + default_timeout=None, + client_info=client_info, + ), + self.list_notebook_runtimes: gapic_v1.method_async.wrap_method( + self.list_notebook_runtimes, + default_timeout=None, + client_info=client_info, + ), + self.delete_notebook_runtime: gapic_v1.method_async.wrap_method( + self.delete_notebook_runtime, + default_timeout=None, + client_info=client_info, + ), + self.upgrade_notebook_runtime: gapic_v1.method_async.wrap_method( + self.upgrade_notebook_runtime, + 
default_timeout=None, + client_info=client_info, + ), + self.start_notebook_runtime: gapic_v1.method_async.wrap_method( + self.start_notebook_runtime, + default_timeout=None, + client_info=client_info, + ), + self.get_notebook_execution_job: gapic_v1.method_async.wrap_method( + self.get_notebook_execution_job, + default_timeout=None, + client_info=client_info, + ), + self.list_notebook_execution_jobs: gapic_v1.method_async.wrap_method( + self.list_notebook_execution_jobs, + default_timeout=None, + client_info=client_info, + ), + self.delete_notebook_execution_job: gapic_v1.method_async.wrap_method( + self.delete_notebook_execution_job, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/notebook_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/notebook_service/transports/rest.py index 67640d90d9..49e4eb4f5f 100644 --- a/google/cloud/aiplatform_v1beta1/services/notebook_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/notebook_service/transports/rest.py @@ -43,6 +43,7 @@ OptionalRetry = Union[retries.Retry, object, None] # type: ignore +from google.cloud.aiplatform_v1beta1.types import notebook_execution_job from google.cloud.aiplatform_v1beta1.types import notebook_runtime from google.cloud.aiplatform_v1beta1.types import notebook_service from google.longrunning import operations_pb2 # type: ignore @@ -91,6 +92,14 @@ def post_create_notebook_runtime_template(self, response): logging.log(f"Received response: {response}") return response + def pre_delete_notebook_execution_job(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_delete_notebook_execution_job(self, response): + logging.log(f"Received response: {response}") + return response + def pre_delete_notebook_runtime(self, request, metadata): logging.log(f"Received request: {request}") return 
request, metadata @@ -107,6 +116,14 @@ def post_delete_notebook_runtime_template(self, response): logging.log(f"Received response: {response}") return response + def pre_get_notebook_execution_job(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_get_notebook_execution_job(self, response): + logging.log(f"Received response: {response}") + return response + def pre_get_notebook_runtime(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata @@ -123,6 +140,14 @@ def post_get_notebook_runtime_template(self, response): logging.log(f"Received response: {response}") return response + def pre_list_notebook_execution_jobs(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_list_notebook_execution_jobs(self, response): + logging.log(f"Received response: {response}") + return response + def pre_list_notebook_runtimes(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata @@ -211,6 +236,31 @@ def post_create_notebook_runtime_template( """ return response + def pre_delete_notebook_execution_job( + self, + request: notebook_service.DeleteNotebookExecutionJobRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[ + notebook_service.DeleteNotebookExecutionJobRequest, Sequence[Tuple[str, str]] + ]: + """Pre-rpc interceptor for delete_notebook_execution_job + + Override in a subclass to manipulate the request or metadata + before they are sent to the NotebookService server. + """ + return request, metadata + + def post_delete_notebook_execution_job( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for delete_notebook_execution_job + + Override in a subclass to manipulate the response + after it is returned by the NotebookService server but before + it is returned to user code. 
+ """ + return response + def pre_delete_notebook_runtime( self, request: notebook_service.DeleteNotebookRuntimeRequest, @@ -261,6 +311,31 @@ def post_delete_notebook_runtime_template( """ return response + def pre_get_notebook_execution_job( + self, + request: notebook_service.GetNotebookExecutionJobRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[ + notebook_service.GetNotebookExecutionJobRequest, Sequence[Tuple[str, str]] + ]: + """Pre-rpc interceptor for get_notebook_execution_job + + Override in a subclass to manipulate the request or metadata + before they are sent to the NotebookService server. + """ + return request, metadata + + def post_get_notebook_execution_job( + self, response: notebook_execution_job.NotebookExecutionJob + ) -> notebook_execution_job.NotebookExecutionJob: + """Post-rpc interceptor for get_notebook_execution_job + + Override in a subclass to manipulate the response + after it is returned by the NotebookService server but before + it is returned to user code. + """ + return response + def pre_get_notebook_runtime( self, request: notebook_service.GetNotebookRuntimeRequest, @@ -309,6 +384,31 @@ def post_get_notebook_runtime_template( """ return response + def pre_list_notebook_execution_jobs( + self, + request: notebook_service.ListNotebookExecutionJobsRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[ + notebook_service.ListNotebookExecutionJobsRequest, Sequence[Tuple[str, str]] + ]: + """Pre-rpc interceptor for list_notebook_execution_jobs + + Override in a subclass to manipulate the request or metadata + before they are sent to the NotebookService server. 
+ """ + return request, metadata + + def post_list_notebook_execution_jobs( + self, response: notebook_service.ListNotebookExecutionJobsResponse + ) -> notebook_service.ListNotebookExecutionJobsResponse: + """Post-rpc interceptor for list_notebook_execution_jobs + + Override in a subclass to manipulate the response + after it is returned by the NotebookService server but before + it is returned to user code. + """ + return response + def pre_list_notebook_runtimes( self, request: notebook_service.ListNotebookRuntimesRequest, @@ -952,10 +1052,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1326,10 +1422,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -1688,10 +1780,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2066,10 +2154,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", 
- "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -2444,10 +2528,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -2798,6 +2878,94 @@ def __call__( resp = self._interceptor.post_create_notebook_runtime_template(resp) return resp + class _DeleteNotebookExecutionJob(NotebookServiceRestStub): + def __hash__(self): + return hash("DeleteNotebookExecutionJob") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: notebook_service.DeleteNotebookExecutionJobRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Call the delete notebook execution + job method over HTTP. + + Args: + request (~.notebook_service.DeleteNotebookExecutionJobRequest): + The request object. Request message for + [NotebookService.DeleteNotebookExecutionJob] + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operations_pb2.Operation: + This resource represents a + long-running operation that is the + result of a network API call. 
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "delete", + "uri": "/v1beta1/{name=projects/*/locations/*/notebookExecutionJobs/*}", + }, + ] + request, metadata = self._interceptor.pre_delete_notebook_execution_job( + request, metadata + ) + pb_request = notebook_service.DeleteNotebookExecutionJobRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. 
+ if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) + resp = self._interceptor.post_delete_notebook_execution_job(resp) + return resp + class _DeleteNotebookRuntime(NotebookServiceRestStub): def __hash__(self): return hash("DeleteNotebookRuntime") @@ -2975,6 +3143,95 @@ def __call__( resp = self._interceptor.post_delete_notebook_runtime_template(resp) return resp + class _GetNotebookExecutionJob(NotebookServiceRestStub): + def __hash__(self): + return hash("GetNotebookExecutionJob") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: notebook_service.GetNotebookExecutionJobRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> notebook_execution_job.NotebookExecutionJob: + r"""Call the get notebook execution + job method over HTTP. + + Args: + request (~.notebook_service.GetNotebookExecutionJobRequest): + The request object. Request message for + [NotebookService.GetNotebookExecutionJob] + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.notebook_execution_job.NotebookExecutionJob: + NotebookExecutionJob represents an + instance of a notebook execution. 
+ + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1beta1/{name=projects/*/locations/*/notebookExecutionJobs/*}", + }, + ] + request, metadata = self._interceptor.pre_get_notebook_execution_job( + request, metadata + ) + pb_request = notebook_service.GetNotebookExecutionJobRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. 
+ if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = notebook_execution_job.NotebookExecutionJob() + pb_resp = notebook_execution_job.NotebookExecutionJob.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_get_notebook_execution_job(resp) + return resp + class _GetNotebookRuntime(NotebookServiceRestStub): def __hash__(self): return hash("GetNotebookRuntime") @@ -3157,6 +3414,95 @@ def __call__( resp = self._interceptor.post_get_notebook_runtime_template(resp) return resp + class _ListNotebookExecutionJobs(NotebookServiceRestStub): + def __hash__(self): + return hash("ListNotebookExecutionJobs") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: notebook_service.ListNotebookExecutionJobsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> notebook_service.ListNotebookExecutionJobsResponse: + r"""Call the list notebook execution + jobs method over HTTP. + + Args: + request (~.notebook_service.ListNotebookExecutionJobsRequest): + The request object. Request message for + [NotebookService.ListNotebookExecutionJobs] + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ + Returns: + ~.notebook_service.ListNotebookExecutionJobsResponse: + Response message for + [NotebookService.CreateNotebookExecutionJob] + + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1beta1/{parent=projects/*/locations/*}/notebookExecutionJobs", + }, + ] + request, metadata = self._interceptor.pre_list_notebook_execution_jobs( + request, metadata + ) + pb_request = notebook_service.ListNotebookExecutionJobsRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. 
+ if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = notebook_service.ListNotebookExecutionJobsResponse() + pb_resp = notebook_service.ListNotebookExecutionJobsResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_list_notebook_execution_jobs(resp) + return resp + class _ListNotebookRuntimes(NotebookServiceRestStub): def __hash__(self): return hash("ListNotebookRuntimes") @@ -3545,6 +3891,16 @@ def create_notebook_runtime_template( # In C++ this would require a dynamic_cast return self._CreateNotebookRuntimeTemplate(self._session, self._host, self._interceptor) # type: ignore + @property + def delete_notebook_execution_job( + self, + ) -> Callable[ + [notebook_service.DeleteNotebookExecutionJobRequest], operations_pb2.Operation + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._DeleteNotebookExecutionJob(self._session, self._host, self._interceptor) # type: ignore + @property def delete_notebook_runtime( self, @@ -3566,6 +3922,17 @@ def delete_notebook_runtime_template( # In C++ this would require a dynamic_cast return self._DeleteNotebookRuntimeTemplate(self._session, self._host, self._interceptor) # type: ignore + @property + def get_notebook_execution_job( + self, + ) -> Callable[ + [notebook_service.GetNotebookExecutionJobRequest], + notebook_execution_job.NotebookExecutionJob, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
+ # In C++ this would require a dynamic_cast + return self._GetNotebookExecutionJob(self._session, self._host, self._interceptor) # type: ignore + @property def get_notebook_runtime( self, @@ -3587,6 +3954,17 @@ def get_notebook_runtime_template( # In C++ this would require a dynamic_cast return self._GetNotebookRuntimeTemplate(self._session, self._host, self._interceptor) # type: ignore + @property + def list_notebook_execution_jobs( + self, + ) -> Callable[ + [notebook_service.ListNotebookExecutionJobsRequest], + notebook_service.ListNotebookExecutionJobsResponse, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._ListNotebookExecutionJobs(self._session, self._host, self._interceptor) # type: ignore + @property def list_notebook_runtimes( self, @@ -4395,10 +4773,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -4826,10 +5200,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -5248,10 +5618,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -5687,10 +6053,6 @@ def __call__( 
"method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -6126,10 +6488,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/async_client.py index 522ed8f534..2203cc9bc8 100644 --- a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -234,7 +235,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, PersistentResourceServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + PersistentResourceServiceTransport, + Callable[..., PersistentResourceServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -246,9 +253,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.PersistentResourceServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. 
+ transport (Optional[Union[str,PersistentResourceServiceTransport,Callable[..., PersistentResourceServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the PersistentResourceServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -390,8 +399,8 @@ async def sample_create_persistent_resource(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, persistent_resource, persistent_resource_id] ) @@ -401,7 +410,14 @@ async def sample_create_persistent_resource(): "the individual field arguments should be set." ) - request = persistent_resource_service.CreatePersistentResourceRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, persistent_resource_service.CreatePersistentResourceRequest + ): + request = persistent_resource_service.CreatePersistentResourceRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -414,11 +430,9 @@ async def sample_create_persistent_resource(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_persistent_resource, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_persistent_resource + ] # Certain fields should be provided within the metadata header; # add these here. @@ -515,8 +529,8 @@ async def sample_get_persistent_resource(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -524,7 +538,12 @@ async def sample_get_persistent_resource(): "the individual field arguments should be set." ) - request = persistent_resource_service.GetPersistentResourceRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, persistent_resource_service.GetPersistentResourceRequest + ): + request = persistent_resource_service.GetPersistentResourceRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -533,11 +552,9 @@ async def sample_get_persistent_resource(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_persistent_resource, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_persistent_resource + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -627,8 +644,8 @@ async def sample_list_persistent_resources(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -636,7 +653,14 @@ async def sample_list_persistent_resources(): "the individual field arguments should be set." ) - request = persistent_resource_service.ListPersistentResourcesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, persistent_resource_service.ListPersistentResourcesRequest + ): + request = persistent_resource_service.ListPersistentResourcesRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -645,11 +669,9 @@ async def sample_list_persistent_resources(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_persistent_resources, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_persistent_resources + ] # Certain fields should be provided within the metadata header; # add these here. @@ -758,8 +780,8 @@ async def sample_delete_persistent_resource(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -767,7 +789,14 @@ async def sample_delete_persistent_resource(): "the individual field arguments should be set." ) - request = persistent_resource_service.DeletePersistentResourceRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, persistent_resource_service.DeletePersistentResourceRequest + ): + request = persistent_resource_service.DeletePersistentResourceRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -776,11 +805,9 @@ async def sample_delete_persistent_resource(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_persistent_resource, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_persistent_resource + ] # Certain fields should be provided within the metadata header; # add these here. @@ -894,8 +921,8 @@ async def sample_update_persistent_resource(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([persistent_resource, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -903,7 +930,14 @@ async def sample_update_persistent_resource(): "the individual field arguments should be set." 
) - request = persistent_resource_service.UpdatePersistentResourceRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, persistent_resource_service.UpdatePersistentResourceRequest + ): + request = persistent_resource_service.UpdatePersistentResourceRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -914,11 +948,9 @@ async def sample_update_persistent_resource(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_persistent_resource, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_persistent_resource + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1022,8 +1054,8 @@ async def sample_reboot_persistent_resource(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1031,7 +1063,14 @@ async def sample_reboot_persistent_resource(): "the individual field arguments should be set." ) - request = persistent_resource_service.RebootPersistentResourceRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance( + request, persistent_resource_service.RebootPersistentResourceRequest + ): + request = persistent_resource_service.RebootPersistentResourceRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1040,11 +1079,9 @@ async def sample_reboot_persistent_resource(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.reboot_persistent_resource, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.reboot_persistent_resource + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/client.py b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/client.py index c32c05573a..b781b6fe76 100644 --- a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -589,7 +590,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, PersistentResourceServiceTransport]] = None, + transport: Optional[ + Union[ + str, + PersistentResourceServiceTransport, + Callable[..., PersistentResourceServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -601,9 +608,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. 
- transport (Union[str, PersistentResourceServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,PersistentResourceServiceTransport,Callable[..., PersistentResourceServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the PersistentResourceServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -717,8 +726,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[PersistentResourceServiceTransport], + Callable[..., PersistentResourceServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., PersistentResourceServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -825,8 +842,8 @@ def sample_create_persistent_resource(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, persistent_resource, persistent_resource_id] ) @@ -836,10 +853,8 @@ def sample_create_persistent_resource(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a persistent_resource_service.CreatePersistentResourceRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, persistent_resource_service.CreatePersistentResourceRequest ): @@ -956,8 +971,8 @@ def sample_get_persistent_resource(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -965,10 +980,8 @@ def sample_get_persistent_resource(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a persistent_resource_service.GetPersistentResourceRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, persistent_resource_service.GetPersistentResourceRequest ): @@ -1070,8 +1083,8 @@ def sample_list_persistent_resources(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1079,10 +1092,8 @@ def sample_list_persistent_resources(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a persistent_resource_service.ListPersistentResourcesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, persistent_resource_service.ListPersistentResourcesRequest ): @@ -1207,8 +1218,8 @@ def sample_delete_persistent_resource(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1216,10 +1227,8 @@ def sample_delete_persistent_resource(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a persistent_resource_service.DeletePersistentResourceRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, persistent_resource_service.DeletePersistentResourceRequest ): @@ -1349,8 +1358,8 @@ def sample_update_persistent_resource(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([persistent_resource, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1358,10 +1367,8 @@ def sample_update_persistent_resource(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a persistent_resource_service.UpdatePersistentResourceRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, persistent_resource_service.UpdatePersistentResourceRequest ): @@ -1483,8 +1490,8 @@ def sample_reboot_persistent_resource(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1492,10 +1499,8 @@ def sample_reboot_persistent_resource(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a persistent_resource_service.RebootPersistentResourceRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, persistent_resource_service.RebootPersistentResourceRequest ): diff --git a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/transports/grpc.py index 9b6bdd72a0..c063fcb2c0 100644 --- a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/transports/grpc.py @@ -57,7 +57,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -77,14 +77,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. 
If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -94,11 +97,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -125,7 +128,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -166,7 +169,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/transports/grpc_asyncio.py index 4c0d8706b3..3b0496a0b1 100644 --- a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -72,7 +74,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -102,7 +103,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -122,15 +123,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -140,11 +144,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -171,7 +175,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -211,7 +215,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -430,6 +436,41 @@ def reboot_persistent_resource( ) return self._stubs["reboot_persistent_resource"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_persistent_resource: gapic_v1.method_async.wrap_method( + self.create_persistent_resource, + default_timeout=None, + client_info=client_info, + ), + self.get_persistent_resource: gapic_v1.method_async.wrap_method( + self.get_persistent_resource, + default_timeout=None, + client_info=client_info, + ), + self.list_persistent_resources: gapic_v1.method_async.wrap_method( + self.list_persistent_resources, + default_timeout=None, + client_info=client_info, + ), + self.delete_persistent_resource: gapic_v1.method_async.wrap_method( + self.delete_persistent_resource, + default_timeout=None, + client_info=client_info, + ), + self.update_persistent_resource: gapic_v1.method_async.wrap_method( + self.update_persistent_resource, + default_timeout=None, + client_info=client_info, + ), + self.reboot_persistent_resource: gapic_v1.method_async.wrap_method( + self.reboot_persistent_resource, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/transports/rest.py index 975a8db392..c049d5113d 100644 --- a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/transports/rest.py +++ 
b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/transports/rest.py @@ -832,10 +832,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1206,10 +1202,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -1568,10 +1560,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -1946,10 +1934,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -2324,10 +2308,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -3889,10 +3869,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -4320,10 +4296,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -4742,10 +4714,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -5181,10 +5149,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -5620,10 +5584,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff 
--git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py index 173c45a4c0..4925666eec 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -237,7 +238,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, PipelineServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, PipelineServiceTransport, Callable[..., PipelineServiceTransport] + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -249,9 +254,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.PipelineServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,PipelineServiceTransport,Callable[..., PipelineServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the PipelineServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -381,8 +388,8 @@ async def sample_create_training_pipeline(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, training_pipeline]) if request is not None and has_flattened_params: raise ValueError( @@ -390,7 +397,10 @@ async def sample_create_training_pipeline(): "the individual field arguments should be set." ) - request = pipeline_service.CreateTrainingPipelineRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.CreateTrainingPipelineRequest): + request = pipeline_service.CreateTrainingPipelineRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -401,11 +411,9 @@ async def sample_create_training_pipeline(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_training_pipeline, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_training_pipeline + ] # Certain fields should be provided within the metadata header; # add these here. @@ -495,8 +503,8 @@ async def sample_get_training_pipeline(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -504,7 +512,10 @@ async def sample_get_training_pipeline(): "the individual field arguments should be set." 
) - request = pipeline_service.GetTrainingPipelineRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.GetTrainingPipelineRequest): + request = pipeline_service.GetTrainingPipelineRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -513,11 +524,9 @@ async def sample_get_training_pipeline(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_training_pipeline, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_training_pipeline + ] # Certain fields should be provided within the metadata header; # add these here. @@ -607,8 +616,8 @@ async def sample_list_training_pipelines(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -616,7 +625,10 @@ async def sample_list_training_pipelines(): "the individual field arguments should be set." ) - request = pipeline_service.ListTrainingPipelinesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.ListTrainingPipelinesRequest): + request = pipeline_service.ListTrainingPipelinesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -625,11 +637,9 @@ async def sample_list_training_pipelines(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_training_pipelines, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_training_pipelines + ] # Certain fields should be provided within the metadata header; # add these here. @@ -738,8 +748,8 @@ async def sample_delete_training_pipeline(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -747,7 +757,10 @@ async def sample_delete_training_pipeline(): "the individual field arguments should be set." ) - request = pipeline_service.DeleteTrainingPipelineRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.DeleteTrainingPipelineRequest): + request = pipeline_service.DeleteTrainingPipelineRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -756,11 +769,9 @@ async def sample_delete_training_pipeline(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_training_pipeline, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_training_pipeline + ] # Certain fields should be provided within the metadata header; # add these here. @@ -857,8 +868,8 @@ async def sample_cancel_training_pipeline(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -866,7 +877,10 @@ async def sample_cancel_training_pipeline(): "the individual field arguments should be set." ) - request = pipeline_service.CancelTrainingPipelineRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.CancelTrainingPipelineRequest): + request = pipeline_service.CancelTrainingPipelineRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -875,11 +889,9 @@ async def sample_cancel_training_pipeline(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_training_pipeline, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.cancel_training_pipeline + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -981,8 +993,8 @@ async def sample_create_pipeline_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, pipeline_job, pipeline_job_id]) if request is not None and has_flattened_params: raise ValueError( @@ -990,7 +1002,10 @@ async def sample_create_pipeline_job(): "the individual field arguments should be set." ) - request = pipeline_service.CreatePipelineJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.CreatePipelineJobRequest): + request = pipeline_service.CreatePipelineJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1003,11 +1018,9 @@ async def sample_create_pipeline_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_pipeline_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_pipeline_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1090,8 +1103,8 @@ async def sample_get_pipeline_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1099,7 +1112,10 @@ async def sample_get_pipeline_job(): "the individual field arguments should be set." ) - request = pipeline_service.GetPipelineJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.GetPipelineJobRequest): + request = pipeline_service.GetPipelineJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1108,11 +1124,9 @@ async def sample_get_pipeline_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_pipeline_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_pipeline_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1200,8 +1214,8 @@ async def sample_list_pipeline_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1209,7 +1223,10 @@ async def sample_list_pipeline_jobs(): "the individual field arguments should be set." ) - request = pipeline_service.ListPipelineJobsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, pipeline_service.ListPipelineJobsRequest): + request = pipeline_service.ListPipelineJobsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1218,11 +1235,9 @@ async def sample_list_pipeline_jobs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_pipeline_jobs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_pipeline_jobs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1331,8 +1346,8 @@ async def sample_delete_pipeline_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1340,7 +1355,10 @@ async def sample_delete_pipeline_job(): "the individual field arguments should be set." ) - request = pipeline_service.DeletePipelineJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.DeletePipelineJobRequest): + request = pipeline_service.DeletePipelineJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1349,11 +1367,9 @@ async def sample_delete_pipeline_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_pipeline_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_pipeline_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1466,8 +1482,8 @@ async def sample_batch_delete_pipeline_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, names]) if request is not None and has_flattened_params: raise ValueError( @@ -1475,7 +1491,10 @@ async def sample_batch_delete_pipeline_jobs(): "the individual field arguments should be set." ) - request = pipeline_service.BatchDeletePipelineJobsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.BatchDeletePipelineJobsRequest): + request = pipeline_service.BatchDeletePipelineJobsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1486,11 +1505,9 @@ async def sample_batch_delete_pipeline_jobs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.batch_delete_pipeline_jobs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_delete_pipeline_jobs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1586,8 +1603,8 @@ async def sample_cancel_pipeline_job(): sent along with the request as metadata. 
""" # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1595,7 +1612,10 @@ async def sample_cancel_pipeline_job(): "the individual field arguments should be set." ) - request = pipeline_service.CancelPipelineJobRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.CancelPipelineJobRequest): + request = pipeline_service.CancelPipelineJobRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1604,11 +1624,9 @@ async def sample_cancel_pipeline_job(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_pipeline_job, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.cancel_pipeline_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1715,8 +1733,8 @@ async def sample_batch_cancel_pipeline_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, names]) if request is not None and has_flattened_params: raise ValueError( @@ -1724,7 +1742,10 @@ async def sample_batch_cancel_pipeline_jobs(): "the individual field arguments should be set." ) - request = pipeline_service.BatchCancelPipelineJobsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, pipeline_service.BatchCancelPipelineJobsRequest): + request = pipeline_service.BatchCancelPipelineJobsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1735,11 +1756,9 @@ async def sample_batch_cancel_pipeline_jobs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.batch_cancel_pipeline_jobs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_cancel_pipeline_jobs + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py index e21d0f64e3..61a70dd872 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -731,7 +732,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, PipelineServiceTransport]] = None, + transport: Optional[ + Union[ + str, PipelineServiceTransport, Callable[..., PipelineServiceTransport] + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -743,9 +748,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, PipelineServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,PipelineServiceTransport,Callable[..., PipelineServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the PipelineServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -857,8 +864,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[PipelineServiceTransport], Callable[..., PipelineServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., PipelineServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -953,8 +967,8 @@ def sample_create_training_pipeline(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, training_pipeline]) if request is not None and has_flattened_params: raise ValueError( @@ -962,10 +976,8 @@ def sample_create_training_pipeline(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.CreateTrainingPipelineRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, pipeline_service.CreateTrainingPipelineRequest): request = pipeline_service.CreateTrainingPipelineRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1067,8 +1079,8 @@ def sample_get_training_pipeline(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1076,10 +1088,8 @@ def sample_get_training_pipeline(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.GetTrainingPipelineRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, pipeline_service.GetTrainingPipelineRequest): request = pipeline_service.GetTrainingPipelineRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1179,8 +1189,8 @@ def sample_list_training_pipelines(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1188,10 +1198,8 @@ def sample_list_training_pipelines(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.ListTrainingPipelinesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, pipeline_service.ListTrainingPipelinesRequest): request = pipeline_service.ListTrainingPipelinesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1310,8 +1318,8 @@ def sample_delete_training_pipeline(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1319,10 +1327,8 @@ def sample_delete_training_pipeline(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.DeleteTrainingPipelineRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, pipeline_service.DeleteTrainingPipelineRequest): request = pipeline_service.DeleteTrainingPipelineRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1429,8 +1435,8 @@ def sample_cancel_training_pipeline(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1438,10 +1444,8 @@ def sample_cancel_training_pipeline(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.CancelTrainingPipelineRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, pipeline_service.CancelTrainingPipelineRequest): request = pipeline_service.CancelTrainingPipelineRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1553,8 +1557,8 @@ def sample_create_pipeline_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, pipeline_job, pipeline_job_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1562,10 +1566,8 @@ def sample_create_pipeline_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.CreatePipelineJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, pipeline_service.CreatePipelineJobRequest): request = pipeline_service.CreatePipelineJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1662,8 +1664,8 @@ def sample_get_pipeline_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1671,10 +1673,8 @@ def sample_get_pipeline_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.GetPipelineJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, pipeline_service.GetPipelineJobRequest): request = pipeline_service.GetPipelineJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1772,8 +1772,8 @@ def sample_list_pipeline_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1781,10 +1781,8 @@ def sample_list_pipeline_jobs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.ListPipelineJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, pipeline_service.ListPipelineJobsRequest): request = pipeline_service.ListPipelineJobsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1903,8 +1901,8 @@ def sample_delete_pipeline_job(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1912,10 +1910,8 @@ def sample_delete_pipeline_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.DeletePipelineJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, pipeline_service.DeletePipelineJobRequest): request = pipeline_service.DeletePipelineJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2038,8 +2034,8 @@ def sample_batch_delete_pipeline_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, names]) if request is not None and has_flattened_params: raise ValueError( @@ -2047,10 +2043,8 @@ def sample_batch_delete_pipeline_jobs(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.BatchDeletePipelineJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, pipeline_service.BatchDeletePipelineJobsRequest): request = pipeline_service.BatchDeletePipelineJobsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2160,8 +2154,8 @@ def sample_cancel_pipeline_job(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2169,10 +2163,8 @@ def sample_cancel_pipeline_job(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.CancelPipelineJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, pipeline_service.CancelPipelineJobRequest): request = pipeline_service.CancelPipelineJobRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2289,8 +2281,8 @@ def sample_batch_cancel_pipeline_jobs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, names]) if request is not None and has_flattened_params: raise ValueError( @@ -2298,10 +2290,8 @@ def sample_batch_cancel_pipeline_jobs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.BatchCancelPipelineJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, pipeline_service.BatchCancelPipelineJobsRequest): request = pipeline_service.BatchCancelPipelineJobsRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py index c6fe96522d..77781369c8 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py @@ -65,7 +65,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -85,14 +85,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. 
+ This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -102,11 +105,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. 
client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -133,7 +136,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. @@ -174,7 +177,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py index 99c4c09e7b..cb16a4f555 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -80,7 +82,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -110,7 +111,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -130,15 +131,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -148,11 +152,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -179,7 +183,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -219,7 +223,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -645,6 +651,71 @@ def batch_cancel_pipeline_jobs( ) return self._stubs["batch_cancel_pipeline_jobs"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_training_pipeline: gapic_v1.method_async.wrap_method( + self.create_training_pipeline, + default_timeout=5.0, + client_info=client_info, + ), + self.get_training_pipeline: gapic_v1.method_async.wrap_method( + self.get_training_pipeline, + default_timeout=5.0, + client_info=client_info, + ), + self.list_training_pipelines: gapic_v1.method_async.wrap_method( + self.list_training_pipelines, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_training_pipeline: gapic_v1.method_async.wrap_method( + self.delete_training_pipeline, + default_timeout=5.0, + client_info=client_info, + ), + self.cancel_training_pipeline: gapic_v1.method_async.wrap_method( + self.cancel_training_pipeline, + default_timeout=5.0, + client_info=client_info, + ), + self.create_pipeline_job: gapic_v1.method_async.wrap_method( + self.create_pipeline_job, + default_timeout=None, + client_info=client_info, + ), + self.get_pipeline_job: gapic_v1.method_async.wrap_method( + self.get_pipeline_job, + default_timeout=None, + client_info=client_info, + ), + self.list_pipeline_jobs: gapic_v1.method_async.wrap_method( + self.list_pipeline_jobs, + default_timeout=None, + client_info=client_info, + ), + self.delete_pipeline_job: gapic_v1.method_async.wrap_method( + self.delete_pipeline_job, + default_timeout=None, + client_info=client_info, + ), + 
self.batch_delete_pipeline_jobs: gapic_v1.method_async.wrap_method( + self.batch_delete_pipeline_jobs, + default_timeout=None, + client_info=client_info, + ), + self.cancel_pipeline_job: gapic_v1.method_async.wrap_method( + self.cancel_pipeline_job, + default_timeout=None, + client_info=client_info, + ), + self.batch_cancel_pipeline_jobs: gapic_v1.method_async.wrap_method( + self.batch_cancel_pipeline_jobs, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/rest.py index 461505401e..481daba9dd 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/rest.py @@ -990,10 +990,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1364,10 +1360,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -1726,10 +1718,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2104,10 +2092,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -2482,10 +2466,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -4605,10 +4585,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -5036,10 +5012,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -5458,10 +5430,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { 
"method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -5897,10 +5865,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -6336,10 +6300,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py index 28011cc656..cb63c4736f 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -209,7 +210,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, PredictionServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + PredictionServiceTransport, + Callable[..., PredictionServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -221,9 +228,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.PredictionServiceTransport]): The - transport to use. 
If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,PredictionServiceTransport,Callable[..., PredictionServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the PredictionServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -365,8 +374,8 @@ async def sample_predict(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters]) if request is not None and has_flattened_params: raise ValueError( @@ -374,7 +383,10 @@ async def sample_predict(): "the individual field arguments should be set." ) - request = prediction_service.PredictRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.PredictRequest): + request = prediction_service.PredictRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -387,11 +399,7 @@ async def sample_predict(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.predict, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[self._client._transport.predict] # Certain fields should be provided within the metadata header; # add these here. @@ -556,8 +564,8 @@ async def sample_raw_predict(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, http_body]) if request is not None and has_flattened_params: raise ValueError( @@ -565,7 +573,10 @@ async def sample_raw_predict(): "the individual field arguments should be set." ) - request = prediction_service.RawPredictRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.RawPredictRequest): + request = prediction_service.RawPredictRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -576,11 +587,9 @@ async def sample_raw_predict(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.raw_predict, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.raw_predict + ] # Certain fields should be provided within the metadata header; # add these here. @@ -657,15 +666,16 @@ async def sample_direct_predict(): """ # Create or coerce a protobuf request object. 
- request = prediction_service.DirectPredictRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.DirectPredictRequest): + request = prediction_service.DirectPredictRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.direct_predict, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.direct_predict + ] # Certain fields should be provided within the metadata header; # add these here. @@ -743,15 +753,16 @@ async def sample_direct_raw_predict(): """ # Create or coerce a protobuf request object. - request = prediction_service.DirectRawPredictRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.DirectRawPredictRequest): + request = prediction_service.DirectRawPredictRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.direct_raw_predict, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.direct_raw_predict + ] # Certain fields should be provided within the metadata header; # add these here. @@ -848,11 +859,9 @@ def request_generator(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.stream_direct_predict, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.stream_direct_predict + ] # Validate the universe domain. self._client._validate_universe_domain() @@ -948,11 +957,9 @@ def request_generator(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.stream_direct_raw_predict, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.stream_direct_raw_predict + ] # Validate the universe domain. self._client._validate_universe_domain() @@ -1042,11 +1049,9 @@ def request_generator(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.streaming_predict, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.streaming_predict + ] # Validate the universe domain. self._client._validate_universe_domain() @@ -1124,15 +1129,16 @@ async def sample_server_streaming_predict(): """ # Create or coerce a protobuf request object. - request = prediction_service.StreamingPredictRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.StreamingPredictRequest): + request = prediction_service.StreamingPredictRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.server_streaming_predict, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.server_streaming_predict + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1234,11 +1240,9 @@ def request_generator(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.streaming_raw_predict, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.streaming_raw_predict + ] # Validate the universe domain. self._client._validate_universe_domain() @@ -1368,8 +1372,8 @@ async def sample_explain(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters, deployed_model_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1377,7 +1381,10 @@ async def sample_explain(): "the individual field arguments should be set." ) - request = prediction_service.ExplainRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.ExplainRequest): + request = prediction_service.ExplainRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1392,11 +1399,7 @@ async def sample_explain(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.explain, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[self._client._transport.explain] # Certain fields should be provided within the metadata header; # add these here. @@ -1499,8 +1502,8 @@ async def sample_count_tokens(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances]) if request is not None and has_flattened_params: raise ValueError( @@ -1508,7 +1511,10 @@ async def sample_count_tokens(): "the individual field arguments should be set." ) - request = prediction_service.CountTokensRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.CountTokensRequest): + request = prediction_service.CountTokensRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1519,11 +1525,9 @@ async def sample_count_tokens(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.count_tokens, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.count_tokens + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1624,8 +1628,8 @@ async def sample_generate_content(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model, contents]) if request is not None and has_flattened_params: raise ValueError( @@ -1633,7 +1637,10 @@ async def sample_generate_content(): "the individual field arguments should be set." ) - request = prediction_service.GenerateContentRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.GenerateContentRequest): + request = prediction_service.GenerateContentRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1644,11 +1651,9 @@ async def sample_generate_content(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.generate_content, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.generate_content + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1751,8 +1756,8 @@ async def sample_stream_generate_content(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model, contents]) if request is not None and has_flattened_params: raise ValueError( @@ -1760,7 +1765,10 @@ async def sample_stream_generate_content(): "the individual field arguments should be set." 
) - request = prediction_service.GenerateContentRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.GenerateContentRequest): + request = prediction_service.GenerateContentRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1771,11 +1779,9 @@ async def sample_stream_generate_content(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.stream_generate_content, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.stream_generate_content + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1797,170 +1803,6 @@ async def sample_stream_generate_content(): # Done; return the response. return response - def chat_completions( - self, - request: Optional[ - Union[prediction_service.ChatCompletionsRequest, dict] - ] = None, - *, - endpoint: Optional[str] = None, - http_body: Optional[httpbody_pb2.HttpBody] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> Awaitable[AsyncIterable[httpbody_pb2.HttpBody]]: - r"""Exposes an OpenAI-compatible endpoint for chat - completions. - - .. code-block:: python - - # This snippet has been automatically generated and should be regarded as a - # code template only. - # It will require modifications to work: - # - It may require correct/in-range values for request initialization. 
- # - It may require specifying regional endpoints when creating the service - # client as shown in: - # https://googleapis.dev/python/google-api-core/latest/client_options.html - from google.cloud import aiplatform_v1beta1 - - async def sample_chat_completions(): - # Create a client - client = aiplatform_v1beta1.PredictionServiceAsyncClient() - - # Initialize request argument(s) - request = aiplatform_v1beta1.ChatCompletionsRequest( - endpoint="endpoint_value", - ) - - # Make the request - stream = await client.chat_completions(request=request) - - # Handle the response - async for response in stream: - print(response) - - Args: - request (Optional[Union[google.cloud.aiplatform_v1beta1.types.ChatCompletionsRequest, dict]]): - The request object. Request message for [PredictionService.ChatCompletions] - endpoint (:class:`str`): - Required. The name of the Endpoint requested to serve - the prediction. Format: - ``projects/{project}/locations/{location}/endpoints/openapi`` - - This corresponds to the ``endpoint`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - http_body (:class:`google.api.httpbody_pb2.HttpBody`): - Optional. The prediction input. - Supports HTTP headers and arbitrary data - payload. - - This corresponds to the ``http_body`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - AsyncIterable[google.api.httpbody_pb2.HttpBody]: - Message that represents an arbitrary HTTP body. It should only be used for - payload formats that can't be represented as JSON, - such as raw binary or an HTML page. 
- - This message can be used both in streaming and - non-streaming API methods in the request as well as - the response. - - It can be used as a top-level request field, which is - convenient if one wants to extract parameters from - either the URL or HTTP template into the request - fields and also want access to the raw HTTP body. - - Example: - - message GetResourceRequest { - // A unique request id. string request_id = 1; - - // The raw HTTP body is bound to this field. - google.api.HttpBody http_body = 2; - - } - - service ResourceService { - rpc GetResource(GetResourceRequest) - returns (google.api.HttpBody); - - rpc UpdateResource(google.api.HttpBody) - returns (google.protobuf.Empty); - - } - - Example with streaming methods: - - service CaldavService { - rpc GetCalendar(stream google.api.HttpBody) - returns (stream google.api.HttpBody); - - rpc UpdateCalendar(stream google.api.HttpBody) - returns (stream google.api.HttpBody); - - } - - Use of this type only changes how the request and - response bodies are handled, all other features will - continue to work unchanged. - - """ - # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - has_flattened_params = any([endpoint, http_body]) - if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = prediction_service.ChatCompletionsRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - if endpoint is not None: - request.endpoint = endpoint - if http_body is not None: - request.http_body = http_body - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.chat_completions, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), - ) - - # Validate the universe domain. - self._client._validate_universe_domain() - - # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - # Done; return the response. - return response - async def list_operations( self, request: Optional[operations_pb2.ListOperationsRequest] = None, diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py index 5dfa236d01..fa197c4d4d 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -583,7 +584,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, PredictionServiceTransport]] = None, + transport: Optional[ + Union[ + str, + PredictionServiceTransport, + Callable[..., PredictionServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -595,9 +602,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, PredictionServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. 
+ transport (Optional[Union[str,PredictionServiceTransport,Callable[..., PredictionServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the PredictionServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -709,8 +718,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[PredictionServiceTransport], + Callable[..., PredictionServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., PredictionServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -817,8 +834,8 @@ def sample_predict(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters]) if request is not None and has_flattened_params: raise ValueError( @@ -826,10 +843,8 @@ def sample_predict(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.PredictRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
+ # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, prediction_service.PredictRequest): request = prediction_service.PredictRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1008,8 +1023,8 @@ def sample_raw_predict(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, http_body]) if request is not None and has_flattened_params: raise ValueError( @@ -1017,10 +1032,8 @@ def sample_raw_predict(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.RawPredictRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, prediction_service.RawPredictRequest): request = prediction_service.RawPredictRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1109,10 +1122,8 @@ def sample_direct_predict(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.DirectPredictRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, prediction_service.DirectPredictRequest): request = prediction_service.DirectPredictRequest(request) @@ -1196,10 +1207,8 @@ def sample_direct_raw_predict(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.DirectRawPredictRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, prediction_service.DirectRawPredictRequest): request = prediction_service.DirectRawPredictRequest(request) @@ -1566,10 +1575,8 @@ def sample_server_streaming_predict(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.StreamingPredictRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, prediction_service.StreamingPredictRequest): request = prediction_service.StreamingPredictRequest(request) @@ -1807,8 +1814,8 @@ def sample_explain(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters, deployed_model_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1816,10 +1823,8 @@ def sample_explain(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.ExplainRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, prediction_service.ExplainRequest): request = prediction_service.ExplainRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1938,8 +1943,8 @@ def sample_count_tokens(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances]) if request is not None and has_flattened_params: raise ValueError( @@ -1947,10 +1952,8 @@ def sample_count_tokens(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.CountTokensRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, prediction_service.CountTokensRequest): request = prediction_service.CountTokensRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2063,8 +2066,8 @@ def sample_generate_content(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([model, contents]) if request is not None and has_flattened_params: raise ValueError( @@ -2072,10 +2075,8 @@ def sample_generate_content(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.GenerateContentRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, prediction_service.GenerateContentRequest): request = prediction_service.GenerateContentRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2190,8 +2191,8 @@ def sample_stream_generate_content(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([model, contents]) if request is not None and has_flattened_params: raise ValueError( @@ -2199,10 +2200,8 @@ def sample_stream_generate_content(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.GenerateContentRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, prediction_service.GenerateContentRequest): request = prediction_service.GenerateContentRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2236,170 +2235,6 @@ def sample_stream_generate_content(): # Done; return the response. 
return response - def chat_completions( - self, - request: Optional[ - Union[prediction_service.ChatCompletionsRequest, dict] - ] = None, - *, - endpoint: Optional[str] = None, - http_body: Optional[httpbody_pb2.HttpBody] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> Iterable[httpbody_pb2.HttpBody]: - r"""Exposes an OpenAI-compatible endpoint for chat - completions. - - .. code-block:: python - - # This snippet has been automatically generated and should be regarded as a - # code template only. - # It will require modifications to work: - # - It may require correct/in-range values for request initialization. - # - It may require specifying regional endpoints when creating the service - # client as shown in: - # https://googleapis.dev/python/google-api-core/latest/client_options.html - from google.cloud import aiplatform_v1beta1 - - def sample_chat_completions(): - # Create a client - client = aiplatform_v1beta1.PredictionServiceClient() - - # Initialize request argument(s) - request = aiplatform_v1beta1.ChatCompletionsRequest( - endpoint="endpoint_value", - ) - - # Make the request - stream = client.chat_completions(request=request) - - # Handle the response - for response in stream: - print(response) - - Args: - request (Union[google.cloud.aiplatform_v1beta1.types.ChatCompletionsRequest, dict]): - The request object. Request message for [PredictionService.ChatCompletions] - endpoint (str): - Required. The name of the Endpoint requested to serve - the prediction. Format: - ``projects/{project}/locations/{location}/endpoints/openapi`` - - This corresponds to the ``endpoint`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - http_body (google.api.httpbody_pb2.HttpBody): - Optional. The prediction input. - Supports HTTP headers and arbitrary data - payload. 
- - This corresponds to the ``http_body`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - Iterable[google.api.httpbody_pb2.HttpBody]: - Message that represents an arbitrary HTTP body. It should only be used for - payload formats that can't be represented as JSON, - such as raw binary or an HTML page. - - This message can be used both in streaming and - non-streaming API methods in the request as well as - the response. - - It can be used as a top-level request field, which is - convenient if one wants to extract parameters from - either the URL or HTTP template into the request - fields and also want access to the raw HTTP body. - - Example: - - message GetResourceRequest { - // A unique request id. string request_id = 1; - - // The raw HTTP body is bound to this field. - google.api.HttpBody http_body = 2; - - } - - service ResourceService { - rpc GetResource(GetResourceRequest) - returns (google.api.HttpBody); - - rpc UpdateResource(google.api.HttpBody) - returns (google.protobuf.Empty); - - } - - Example with streaming methods: - - service CaldavService { - rpc GetCalendar(stream google.api.HttpBody) - returns (stream google.api.HttpBody); - - rpc UpdateCalendar(stream google.api.HttpBody) - returns (stream google.api.HttpBody); - - } - - Use of this type only changes how the request and - response bodies are handled, all other features will - continue to work unchanged. - - """ - # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
- has_flattened_params = any([endpoint, http_body]) - if request is not None and has_flattened_params: - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.ChatCompletionsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, prediction_service.ChatCompletionsRequest): - request = prediction_service.ChatCompletionsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. - if endpoint is not None: - request.endpoint = endpoint - if http_body is not None: - request.http_body = http_body - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.chat_completions] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), - ) - - # Validate the universe domain. - self._validate_universe_domain() - - # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - # Done; return the response. 
- return response - def __enter__(self) -> "PredictionServiceClient": return self diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py index 44fb32635a..2979e418c3 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py @@ -197,11 +197,6 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), - self.chat_completions: gapic_v1.method.wrap_method( - self.chat_completions, - default_timeout=None, - client_info=client_info, - ), } def close(self): @@ -366,15 +361,6 @@ def stream_generate_content( ]: raise NotImplementedError() - @property - def chat_completions( - self, - ) -> Callable[ - [prediction_service.ChatCompletionsRequest], - Union[httpbody_pb2.HttpBody, Awaitable[httpbody_pb2.HttpBody]], - ]: - raise NotImplementedError() - @property def list_operations( self, diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py index e40bf8e2b7..1746269575 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py @@ -55,7 +55,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -75,14 +75,17 @@ def __init__( credentials identify the application to the service; 
if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -92,11 +95,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. 
quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -122,7 +125,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. @@ -163,7 +166,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -635,33 +640,6 @@ def stream_generate_content( ) return self._stubs["stream_generate_content"] - @property - def chat_completions( - self, - ) -> Callable[[prediction_service.ChatCompletionsRequest], httpbody_pb2.HttpBody]: - r"""Return a callable for the chat completions method over gRPC. - - Exposes an OpenAI-compatible endpoint for chat - completions. - - Returns: - Callable[[~.ChatCompletionsRequest], - ~.HttpBody]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. 
- if "chat_completions" not in self._stubs: - self._stubs["chat_completions"] = self.grpc_channel.unary_stream( - "/google.cloud.aiplatform.v1beta1.PredictionService/ChatCompletions", - request_serializer=prediction_service.ChatCompletionsRequest.serialize, - response_deserializer=httpbody_pb2.HttpBody.FromString, - ) - return self._stubs["chat_completions"] - def close(self): self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py index 27e0595610..db1c499eaf 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -70,7 +72,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -100,7 +101,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -120,15 +121,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -138,11 +142,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -168,7 +172,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. @@ -208,7 +212,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -643,34 +649,75 @@ def stream_generate_content( ) return self._stubs["stream_generate_content"] - @property - def chat_completions( - self, - ) -> Callable[ - [prediction_service.ChatCompletionsRequest], Awaitable[httpbody_pb2.HttpBody] - ]: - r"""Return a callable for the chat completions method over gRPC. - - Exposes an OpenAI-compatible endpoint for chat - completions. - - Returns: - Callable[[~.ChatCompletionsRequest], - Awaitable[~.HttpBody]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. 
- # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "chat_completions" not in self._stubs: - self._stubs["chat_completions"] = self.grpc_channel.unary_stream( - "/google.cloud.aiplatform.v1beta1.PredictionService/ChatCompletions", - request_serializer=prediction_service.ChatCompletionsRequest.serialize, - response_deserializer=httpbody_pb2.HttpBody.FromString, - ) - return self._stubs["chat_completions"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.predict: gapic_v1.method_async.wrap_method( + self.predict, + default_timeout=5.0, + client_info=client_info, + ), + self.raw_predict: gapic_v1.method_async.wrap_method( + self.raw_predict, + default_timeout=None, + client_info=client_info, + ), + self.direct_predict: gapic_v1.method_async.wrap_method( + self.direct_predict, + default_timeout=None, + client_info=client_info, + ), + self.direct_raw_predict: gapic_v1.method_async.wrap_method( + self.direct_raw_predict, + default_timeout=None, + client_info=client_info, + ), + self.stream_direct_predict: gapic_v1.method_async.wrap_method( + self.stream_direct_predict, + default_timeout=None, + client_info=client_info, + ), + self.stream_direct_raw_predict: gapic_v1.method_async.wrap_method( + self.stream_direct_raw_predict, + default_timeout=None, + client_info=client_info, + ), + self.streaming_predict: gapic_v1.method_async.wrap_method( + self.streaming_predict, + default_timeout=None, + client_info=client_info, + ), + self.server_streaming_predict: gapic_v1.method_async.wrap_method( + self.server_streaming_predict, + default_timeout=None, + client_info=client_info, + ), + self.streaming_raw_predict: gapic_v1.method_async.wrap_method( + self.streaming_raw_predict, + default_timeout=None, + client_info=client_info, + ), + self.explain: gapic_v1.method_async.wrap_method( + 
self.explain, + default_timeout=5.0, + client_info=client_info, + ), + self.count_tokens: gapic_v1.method_async.wrap_method( + self.count_tokens, + default_timeout=None, + client_info=client_info, + ), + self.generate_content: gapic_v1.method_async.wrap_method( + self.generate_content, + default_timeout=None, + client_info=client_info, + ), + self.stream_generate_content: gapic_v1.method_async.wrap_method( + self.stream_generate_content, + default_timeout=None, + client_info=client_info, + ), + } def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/rest.py index 7f907099ba..2da7be1d7c 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/rest.py @@ -74,14 +74,6 @@ class PredictionServiceRestInterceptor: .. code-block:: python class MyCustomPredictionServiceInterceptor(PredictionServiceRestInterceptor): - def pre_chat_completions(self, request, metadata): - logging.log(f"Received request: {request}") - return request, metadata - - def post_chat_completions(self, response): - logging.log(f"Received response: {response}") - return response - def pre_count_tokens(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata @@ -160,29 +152,6 @@ def post_stream_generate_content(self, response): """ - def pre_chat_completions( - self, - request: prediction_service.ChatCompletionsRequest, - metadata: Sequence[Tuple[str, str]], - ) -> Tuple[prediction_service.ChatCompletionsRequest, Sequence[Tuple[str, str]]]: - """Pre-rpc interceptor for chat_completions - - Override in a subclass to manipulate the request or metadata - before they are sent to the PredictionService server. 
- """ - return request, metadata - - def post_chat_completions( - self, response: rest_streaming.ResponseIterator - ) -> rest_streaming.ResponseIterator: - """Post-rpc interceptor for chat_completions - - Override in a subclass to manipulate the response - after it is returned by the PredictionService server but before - it is returned to user code. - """ - return response - def pre_count_tokens( self, request: prediction_service.CountTokensRequest, @@ -716,144 +685,6 @@ def __init__( self._interceptor = interceptor or PredictionServiceRestInterceptor() self._prep_wrapped_messages(client_info) - class _ChatCompletions(PredictionServiceRestStub): - def __hash__(self): - return hash("ChatCompletions") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return { - k: v - for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() - if k not in message_dict - } - - def __call__( - self, - request: prediction_service.ChatCompletionsRequest, - *, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> rest_streaming.ResponseIterator: - r"""Call the chat completions method over HTTP. - - Args: - request (~.prediction_service.ChatCompletionsRequest): - The request object. Request message for [PredictionService.ChatCompletions] - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.httpbody_pb2.HttpBody: - Message that represents an arbitrary HTTP body. It - should only be used for payload formats that can't be - represented as JSON, such as raw binary or an HTML page. - - This message can be used both in streaming and - non-streaming API methods in the request as well as the - response. 
- - It can be used as a top-level request field, which is - convenient if one wants to extract parameters from - either the URL or HTTP template into the request fields - and also want access to the raw HTTP body. - - Example: - - :: - - message GetResourceRequest { - // A unique request id. - string request_id = 1; - - // The raw HTTP body is bound to this field. - google.api.HttpBody http_body = 2; - - } - - service ResourceService { - rpc GetResource(GetResourceRequest) - returns (google.api.HttpBody); - rpc UpdateResource(google.api.HttpBody) - returns (google.protobuf.Empty); - - } - - Example with streaming methods: - - :: - - service CaldavService { - rpc GetCalendar(stream google.api.HttpBody) - returns (stream google.api.HttpBody); - rpc UpdateCalendar(stream google.api.HttpBody) - returns (stream google.api.HttpBody); - - } - - Use of this type only changes how the request and - response bodies are handled, all other features will - continue to work unchanged. - - """ - - http_options: List[Dict[str, str]] = [ - { - "method": "post", - "uri": "/v1beta1/{endpoint=projects/*/locations/*/endpoints/*}/chat/completions", - "body": "http_body", - }, - ] - request, metadata = self._interceptor.pre_chat_completions( - request, metadata - ) - pb_request = prediction_service.ChatCompletionsRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) - - # Jsonify the request body - - body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=False - ) - uri = transcoded_request["uri"] - method = transcoded_request["method"] - - # Jsonify the query params - query_params = json.loads( - json_format.MessageToJson( - transcoded_request["query_params"], - use_integers_for_enums=False, - ) - ) - query_params.update(self._get_unset_required_fields(query_params)) - - # Send the request - headers = dict(metadata) - headers["Content-Type"] = "application/json" - response = getattr(self._session, method)( - 
"{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), - data=body, - ) - - # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception - # subclass. - if response.status_code >= 400: - raise core_exceptions.from_http_response(response) - - # Return the response - resp = rest_streaming.ResponseIterator(response, httpbody_pb2.HttpBody) - resp = self._interceptor.post_chat_completions(resp) - return resp - class _CountTokens(PredictionServiceRestStub): def __hash__(self): return hash("CountTokens") @@ -1841,14 +1672,6 @@ def __call__( "Method StreamingRawPredict is not available over REST transport" ) - @property - def chat_completions( - self, - ) -> Callable[[prediction_service.ChatCompletionsRequest], httpbody_pb2.HttpBody]: - # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. - # In C++ this would require a dynamic_cast - return self._ChatCompletions(self._session, self._host, self._interceptor) # type: ignore - @property def count_tokens( self, @@ -2752,10 +2575,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -3183,10 +3002,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -3605,10 +3420,6 @@ def __call__( "method": "get", "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -4044,10 +3855,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -4483,10 +4290,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/async_client.py index 2f6f2b6b0c..573bdee799 100644 --- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -213,8 +214,12 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[ - str, ReasoningEngineExecutionServiceTransport + transport: Optional[ + Union[ + str, + ReasoningEngineExecutionServiceTransport, + Callable[..., ReasoningEngineExecutionServiceTransport], + ] ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, @@ -227,9 
+232,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.ReasoningEngineExecutionServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,ReasoningEngineExecutionServiceTransport,Callable[..., ReasoningEngineExecutionServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ReasoningEngineExecutionServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -331,17 +338,20 @@ async def sample_query_reasoning_engine(): """ # Create or coerce a protobuf request object. - request = reasoning_engine_execution_service.QueryReasoningEngineRequest( - request - ) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, reasoning_engine_execution_service.QueryReasoningEngineRequest + ): + request = reasoning_engine_execution_service.QueryReasoningEngineRequest( + request + ) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.query_reasoning_engine, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.query_reasoning_engine + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/client.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/client.py index c4d9067e92..bc9b7fd1fa 100644 --- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -542,7 +543,11 @@ def __init__( *, credentials: Optional[ga_credentials.Credentials] = None, transport: Optional[ - Union[str, ReasoningEngineExecutionServiceTransport] + Union[ + str, + ReasoningEngineExecutionServiceTransport, + Callable[..., ReasoningEngineExecutionServiceTransport], + ] ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, @@ -555,9 +560,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ReasoningEngineExecutionServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,ReasoningEngineExecutionServiceTransport,Callable[..., ReasoningEngineExecutionServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ReasoningEngineExecutionServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -675,8 +682,18 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[ReasoningEngineExecutionServiceTransport], + Callable[..., ReasoningEngineExecutionServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast( + Callable[..., ReasoningEngineExecutionServiceTransport], transport + ) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -743,10 +760,8 @@ def sample_query_reasoning_engine(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a reasoning_engine_execution_service.QueryReasoningEngineRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, reasoning_engine_execution_service.QueryReasoningEngineRequest ): diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/grpc.py index a284f9190c..7bb838ae18 100644 --- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/grpc.py @@ -56,7 +56,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -76,14 +76,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. 
If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -93,11 +96,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -123,7 +126,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -164,7 +167,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/grpc_asyncio.py index 59f76f4b14..1c7b740102 100644 --- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -71,7 +73,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -101,7 +102,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -121,15 +122,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -139,11 +143,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -169,7 +173,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -209,7 +213,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -267,6 +273,16 @@ def query_reasoning_engine( ) return self._stubs["query_reasoning_engine"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.query_reasoning_engine: gapic_v1.method_async.wrap_method( + self.query_reasoning_engine, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/rest.py index 333c49849e..6e70d6d995 100644 --- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/rest.py @@ -1321,10 +1321,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1752,10 +1748,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": 
"/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -2174,10 +2166,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2613,10 +2601,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -3052,10 +3036,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/async_client.py index 6d9be4bcf9..805f06a5ee 100644 --- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -213,7 +214,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, ReasoningEngineServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + ReasoningEngineServiceTransport, + Callable[..., ReasoningEngineServiceTransport], + ] + ] = "grpc_asyncio", client_options: 
Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -225,9 +232,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.ReasoningEngineServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,ReasoningEngineServiceTransport,Callable[..., ReasoningEngineServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ReasoningEngineServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -356,8 +365,8 @@ async def sample_create_reasoning_engine(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, reasoning_engine]) if request is not None and has_flattened_params: raise ValueError( @@ -365,7 +374,12 @@ async def sample_create_reasoning_engine(): "the individual field arguments should be set." ) - request = reasoning_engine_service.CreateReasoningEngineRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance( + request, reasoning_engine_service.CreateReasoningEngineRequest + ): + request = reasoning_engine_service.CreateReasoningEngineRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -376,11 +390,9 @@ async def sample_create_reasoning_engine(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_reasoning_engine, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_reasoning_engine + ] # Certain fields should be provided within the metadata header; # add these here. @@ -476,8 +488,8 @@ async def sample_get_reasoning_engine(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -485,7 +497,10 @@ async def sample_get_reasoning_engine(): "the individual field arguments should be set." ) - request = reasoning_engine_service.GetReasoningEngineRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, reasoning_engine_service.GetReasoningEngineRequest): + request = reasoning_engine_service.GetReasoningEngineRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -494,11 +509,9 @@ async def sample_get_reasoning_engine(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_reasoning_engine, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_reasoning_engine + ] # Certain fields should be provided within the metadata header; # add these here. @@ -588,8 +601,8 @@ async def sample_list_reasoning_engines(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -597,7 +610,12 @@ async def sample_list_reasoning_engines(): "the individual field arguments should be set." ) - request = reasoning_engine_service.ListReasoningEnginesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, reasoning_engine_service.ListReasoningEnginesRequest + ): + request = reasoning_engine_service.ListReasoningEnginesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -606,11 +624,9 @@ async def sample_list_reasoning_engines(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_reasoning_engines, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_reasoning_engines + ] # Certain fields should be provided within the metadata header; # add these here. @@ -719,8 +735,8 @@ async def sample_delete_reasoning_engine(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -728,7 +744,12 @@ async def sample_delete_reasoning_engine(): "the individual field arguments should be set." ) - request = reasoning_engine_service.DeleteReasoningEngineRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, reasoning_engine_service.DeleteReasoningEngineRequest + ): + request = reasoning_engine_service.DeleteReasoningEngineRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -737,11 +758,9 @@ async def sample_delete_reasoning_engine(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_reasoning_engine, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_reasoning_engine + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/client.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/client.py index 7c03ef5e04..dc374de35a 100644 --- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -543,7 +544,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, ReasoningEngineServiceTransport]] = None, + transport: Optional[ + Union[ + str, + ReasoningEngineServiceTransport, + Callable[..., ReasoningEngineServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -555,9 +562,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ReasoningEngineServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,ReasoningEngineServiceTransport,Callable[..., ReasoningEngineServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ReasoningEngineServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -669,8 +678,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[ReasoningEngineServiceTransport], + Callable[..., ReasoningEngineServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., ReasoningEngineServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -764,8 +781,8 @@ def sample_create_reasoning_engine(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, reasoning_engine]) if request is not None and has_flattened_params: raise ValueError( @@ -773,10 +790,8 @@ def sample_create_reasoning_engine(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a reasoning_engine_service.CreateReasoningEngineRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, reasoning_engine_service.CreateReasoningEngineRequest ): @@ -886,8 +901,8 @@ def sample_get_reasoning_engine(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -895,10 +910,8 @@ def sample_get_reasoning_engine(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a reasoning_engine_service.GetReasoningEngineRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, reasoning_engine_service.GetReasoningEngineRequest): request = reasoning_engine_service.GetReasoningEngineRequest(request) # If we have keyword arguments corresponding to fields on the @@ -998,8 +1011,8 @@ def sample_list_reasoning_engines(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1007,10 +1020,8 @@ def sample_list_reasoning_engines(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a reasoning_engine_service.ListReasoningEnginesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, reasoning_engine_service.ListReasoningEnginesRequest ): @@ -1131,8 +1142,8 @@ def sample_delete_reasoning_engine(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1140,10 +1151,8 @@ def sample_delete_reasoning_engine(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a reasoning_engine_service.DeleteReasoningEngineRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, reasoning_engine_service.DeleteReasoningEngineRequest ): diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/transports/grpc.py index ee85a91b94..3b82d08450 100644 --- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/transports/grpc.py @@ -56,7 +56,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -76,14 +76,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. 
If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -93,11 +96,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -124,7 +127,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -165,7 +168,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/transports/grpc_asyncio.py index 39ee7eef5e..2967c5bedf 100644 --- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -71,7 +73,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -101,7 +102,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -121,15 +122,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -139,11 +143,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -170,7 +174,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -210,7 +214,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -371,6 +377,31 @@ def delete_reasoning_engine( ) return self._stubs["delete_reasoning_engine"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_reasoning_engine: gapic_v1.method_async.wrap_method( + self.create_reasoning_engine, + default_timeout=None, + client_info=client_info, + ), + self.get_reasoning_engine: gapic_v1.method_async.wrap_method( + self.get_reasoning_engine, + default_timeout=None, + client_info=client_info, + ), + self.list_reasoning_engines: gapic_v1.method_async.wrap_method( + self.list_reasoning_engines, + default_timeout=None, + client_info=client_info, + ), + self.delete_reasoning_engine: gapic_v1.method_async.wrap_method( + self.delete_reasoning_engine, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/transports/rest.py index 970810f0a0..fb66f41d9a 100644 --- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/transports/rest.py @@ -759,10 +759,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1133,10 +1129,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -1495,10 +1487,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -1873,10 +1861,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -2251,10 +2235,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -3589,10 +3569,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - 
"method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -4020,10 +3996,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -4442,10 +4414,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -4881,10 +4849,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -5320,10 +5284,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/schedule_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/schedule_service/async_client.py index 7c5a077e74..204a041dd1 100644 --- a/google/cloud/aiplatform_v1beta1/services/schedule_service/async_client.py +++ 
b/google/cloud/aiplatform_v1beta1/services/schedule_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -46,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.schedule_service import pagers from google.cloud.aiplatform_v1beta1.types import model_monitoring_service +from google.cloud.aiplatform_v1beta1.types import notebook_service from google.cloud.aiplatform_v1beta1.types import operation as gca_operation from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import schedule @@ -108,6 +110,18 @@ class ScheduleServiceAsyncClient: ) network_path = staticmethod(ScheduleServiceClient.network_path) parse_network_path = staticmethod(ScheduleServiceClient.parse_network_path) + notebook_execution_job_path = staticmethod( + ScheduleServiceClient.notebook_execution_job_path + ) + parse_notebook_execution_job_path = staticmethod( + ScheduleServiceClient.parse_notebook_execution_job_path + ) + notebook_runtime_template_path = staticmethod( + ScheduleServiceClient.notebook_runtime_template_path + ) + parse_notebook_runtime_template_path = staticmethod( + ScheduleServiceClient.parse_notebook_runtime_template_path + ) pipeline_job_path = staticmethod(ScheduleServiceClient.pipeline_job_path) parse_pipeline_job_path = staticmethod( ScheduleServiceClient.parse_pipeline_job_path @@ -244,7 +258,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, ScheduleServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, ScheduleServiceTransport, Callable[..., ScheduleServiceTransport] + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -256,9 +274,11 @@ def __init__( credentials identify the application to the 
service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.ScheduleServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,ScheduleServiceTransport,Callable[..., ScheduleServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ScheduleServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -382,8 +402,8 @@ async def sample_create_schedule(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, schedule]) if request is not None and has_flattened_params: raise ValueError( @@ -391,7 +411,10 @@ async def sample_create_schedule(): "the individual field arguments should be set." ) - request = schedule_service.CreateScheduleRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, schedule_service.CreateScheduleRequest): + request = schedule_service.CreateScheduleRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -402,11 +425,9 @@ async def sample_create_schedule(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_schedule, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_schedule + ] # Certain fields should be provided within the metadata header; # add these here. @@ -504,8 +525,8 @@ async def sample_delete_schedule(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -513,7 +534,10 @@ async def sample_delete_schedule(): "the individual field arguments should be set." ) - request = schedule_service.DeleteScheduleRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, schedule_service.DeleteScheduleRequest): + request = schedule_service.DeleteScheduleRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -522,11 +546,9 @@ async def sample_delete_schedule(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_schedule, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_schedule + ] # Certain fields should be provided within the metadata header; # add these here. @@ -619,8 +641,8 @@ async def sample_get_schedule(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -628,7 +650,10 @@ async def sample_get_schedule(): "the individual field arguments should be set." ) - request = schedule_service.GetScheduleRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, schedule_service.GetScheduleRequest): + request = schedule_service.GetScheduleRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -637,11 +662,9 @@ async def sample_get_schedule(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_schedule, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_schedule + ] # Certain fields should be provided within the metadata header; # add these here. @@ -729,8 +752,8 @@ async def sample_list_schedules(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -738,7 +761,10 @@ async def sample_list_schedules(): "the individual field arguments should be set." 
) - request = schedule_service.ListSchedulesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, schedule_service.ListSchedulesRequest): + request = schedule_service.ListSchedulesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -747,11 +773,9 @@ async def sample_list_schedules(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_schedules, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_schedules + ] # Certain fields should be provided within the metadata header; # add these here. @@ -838,8 +862,8 @@ async def sample_pause_schedule(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -847,7 +871,10 @@ async def sample_pause_schedule(): "the individual field arguments should be set." ) - request = schedule_service.PauseScheduleRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, schedule_service.PauseScheduleRequest): + request = schedule_service.PauseScheduleRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -856,11 +883,9 @@ async def sample_pause_schedule(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.pause_schedule, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.pause_schedule + ] # Certain fields should be provided within the metadata header; # add these here. @@ -954,8 +979,8 @@ async def sample_resume_schedule(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, catch_up]) if request is not None and has_flattened_params: raise ValueError( @@ -963,7 +988,10 @@ async def sample_resume_schedule(): "the individual field arguments should be set." ) - request = schedule_service.ResumeScheduleRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, schedule_service.ResumeScheduleRequest): + request = schedule_service.ResumeScheduleRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -974,11 +1002,9 @@ async def sample_resume_schedule(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.resume_schedule, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.resume_schedule + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -1084,8 +1110,8 @@ async def sample_update_schedule(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([schedule, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1093,7 +1119,10 @@ async def sample_update_schedule(): "the individual field arguments should be set." ) - request = schedule_service.UpdateScheduleRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, schedule_service.UpdateScheduleRequest): + request = schedule_service.UpdateScheduleRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1104,11 +1133,9 @@ async def sample_update_schedule(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_schedule, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_schedule + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1beta1/services/schedule_service/client.py b/google/cloud/aiplatform_v1beta1/services/schedule_service/client.py index f6ff65c41d..e93ad36701 100644 --- a/google/cloud/aiplatform_v1beta1/services/schedule_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/schedule_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -51,6 +52,7 @@ from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.schedule_service import pagers from google.cloud.aiplatform_v1beta1.types import model_monitoring_service +from google.cloud.aiplatform_v1beta1.types import notebook_service from google.cloud.aiplatform_v1beta1.types import operation as gca_operation from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import schedule @@ -423,6 +425,50 @@ def parse_network_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def notebook_execution_job_path( + project: str, + location: str, + notebook_execution_job: str, + ) -> str: + """Returns a fully-qualified notebook_execution_job string.""" + return "projects/{project}/locations/{location}/notebookExecutionJobs/{notebook_execution_job}".format( + project=project, + location=location, + notebook_execution_job=notebook_execution_job, + ) + + @staticmethod + def parse_notebook_execution_job_path(path: str) -> Dict[str, str]: + """Parses a notebook_execution_job path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/notebookExecutionJobs/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def notebook_runtime_template_path( + project: str, + location: str, + notebook_runtime_template: str, + ) -> str: + """Returns a fully-qualified notebook_runtime_template string.""" + return 
"projects/{project}/locations/{location}/notebookRuntimeTemplates/{notebook_runtime_template}".format( + project=project, + location=location, + notebook_runtime_template=notebook_runtime_template, + ) + + @staticmethod + def parse_notebook_runtime_template_path(path: str) -> Dict[str, str]: + """Parses a notebook_runtime_template path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/notebookRuntimeTemplates/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def pipeline_job_path( project: str, @@ -792,7 +838,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, ScheduleServiceTransport]] = None, + transport: Optional[ + Union[ + str, ScheduleServiceTransport, Callable[..., ScheduleServiceTransport] + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -804,9 +854,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ScheduleServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,ScheduleServiceTransport,Callable[..., ScheduleServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ScheduleServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -918,8 +970,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[ScheduleServiceTransport], Callable[..., ScheduleServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., ScheduleServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -1008,8 +1067,8 @@ def sample_create_schedule(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, schedule]) if request is not None and has_flattened_params: raise ValueError( @@ -1017,10 +1076,8 @@ def sample_create_schedule(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a schedule_service.CreateScheduleRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, schedule_service.CreateScheduleRequest): request = schedule_service.CreateScheduleRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1130,8 +1187,8 @@ def sample_delete_schedule(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1139,10 +1196,8 @@ def sample_delete_schedule(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a schedule_service.DeleteScheduleRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, schedule_service.DeleteScheduleRequest): request = schedule_service.DeleteScheduleRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1245,8 +1300,8 @@ def sample_get_schedule(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1254,10 +1309,8 @@ def sample_get_schedule(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a schedule_service.GetScheduleRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, schedule_service.GetScheduleRequest): request = schedule_service.GetScheduleRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1355,8 +1408,8 @@ def sample_list_schedules(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1364,10 +1417,8 @@ def sample_list_schedules(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a schedule_service.ListSchedulesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, schedule_service.ListSchedulesRequest): request = schedule_service.ListSchedulesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1464,8 +1515,8 @@ def sample_pause_schedule(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1473,10 +1524,8 @@ def sample_pause_schedule(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a schedule_service.PauseScheduleRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, schedule_service.PauseScheduleRequest): request = schedule_service.PauseScheduleRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1580,8 +1629,8 @@ def sample_resume_schedule(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name, catch_up]) if request is not None and has_flattened_params: raise ValueError( @@ -1589,10 +1638,8 @@ def sample_resume_schedule(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a schedule_service.ResumeScheduleRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, schedule_service.ResumeScheduleRequest): request = schedule_service.ResumeScheduleRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1710,8 +1757,8 @@ def sample_update_schedule(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([schedule, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1719,10 +1766,8 @@ def sample_update_schedule(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a schedule_service.UpdateScheduleRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, schedule_service.UpdateScheduleRequest): request = schedule_service.UpdateScheduleRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1beta1/services/schedule_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/schedule_service/transports/grpc.py index c743dcc45f..abd279d319 100644 --- a/google/cloud/aiplatform_v1beta1/services/schedule_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/schedule_service/transports/grpc.py @@ -60,7 +60,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -80,14 +80,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. 
credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -97,11 +100,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. 
client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -128,7 +131,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. @@ -169,7 +172,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/schedule_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/schedule_service/transports/grpc_asyncio.py index 8ab79ee908..cf1d29a7fb 100644 --- a/google/cloud/aiplatform_v1beta1/services/schedule_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/schedule_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -75,7 +77,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -105,7 +106,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -125,15 +126,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -143,11 +147,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -174,7 +178,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -214,7 +218,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -468,6 +474,46 @@ def update_schedule( ) return self._stubs["update_schedule"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_schedule: gapic_v1.method_async.wrap_method( + self.create_schedule, + default_timeout=None, + client_info=client_info, + ), + self.delete_schedule: gapic_v1.method_async.wrap_method( + self.delete_schedule, + default_timeout=None, + client_info=client_info, + ), + self.get_schedule: gapic_v1.method_async.wrap_method( + self.get_schedule, + default_timeout=None, + client_info=client_info, + ), + self.list_schedules: gapic_v1.method_async.wrap_method( + self.list_schedules, + default_timeout=None, + client_info=client_info, + ), + self.pause_schedule: gapic_v1.method_async.wrap_method( + self.pause_schedule, + default_timeout=None, + client_info=client_info, + ), + self.resume_schedule: gapic_v1.method_async.wrap_method( + self.resume_schedule, + default_timeout=None, + client_info=client_info, + ), + self.update_schedule: gapic_v1.method_async.wrap_method( + self.update_schedule, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/schedule_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/schedule_service/transports/rest.py index 33ef3a10ca..78b46a7a52 100644 --- a/google/cloud/aiplatform_v1beta1/services/schedule_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/schedule_service/transports/rest.py @@ 
-816,10 +816,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1190,10 +1186,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -1552,10 +1544,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -1930,10 +1918,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -2308,10 +2292,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": 
"/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -3904,10 +3884,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -4335,10 +4311,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -4757,10 +4729,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -5196,10 +5164,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -5635,10 +5599,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py 
b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py index d153edb832..cf45a9c5a1 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -217,7 +218,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, SpecialistPoolServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + SpecialistPoolServiceTransport, + Callable[..., SpecialistPoolServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -229,9 +236,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.SpecialistPoolServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,SpecialistPoolServiceTransport,Callable[..., SpecialistPoolServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the SpecialistPoolServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -367,8 +376,8 @@ async def sample_create_specialist_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, specialist_pool]) if request is not None and has_flattened_params: raise ValueError( @@ -376,7 +385,10 @@ async def sample_create_specialist_pool(): "the individual field arguments should be set." ) - request = specialist_pool_service.CreateSpecialistPoolRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, specialist_pool_service.CreateSpecialistPoolRequest): + request = specialist_pool_service.CreateSpecialistPoolRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -387,11 +399,9 @@ async def sample_create_specialist_pool(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_specialist_pool, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_specialist_pool + ] # Certain fields should be provided within the metadata header; # add these here. @@ -495,8 +505,8 @@ async def sample_get_specialist_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -504,7 +514,10 @@ async def sample_get_specialist_pool(): "the individual field arguments should be set." 
) - request = specialist_pool_service.GetSpecialistPoolRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, specialist_pool_service.GetSpecialistPoolRequest): + request = specialist_pool_service.GetSpecialistPoolRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -513,11 +526,9 @@ async def sample_get_specialist_pool(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_specialist_pool, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_specialist_pool + ] # Certain fields should be provided within the metadata header; # add these here. @@ -607,8 +618,8 @@ async def sample_list_specialist_pools(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -616,7 +627,10 @@ async def sample_list_specialist_pools(): "the individual field arguments should be set." ) - request = specialist_pool_service.ListSpecialistPoolsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, specialist_pool_service.ListSpecialistPoolsRequest): + request = specialist_pool_service.ListSpecialistPoolsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -625,11 +639,9 @@ async def sample_list_specialist_pools(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_specialist_pools, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_specialist_pools + ] # Certain fields should be provided within the metadata header; # add these here. @@ -739,8 +751,8 @@ async def sample_delete_specialist_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -748,7 +760,10 @@ async def sample_delete_specialist_pool(): "the individual field arguments should be set." ) - request = specialist_pool_service.DeleteSpecialistPoolRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, specialist_pool_service.DeleteSpecialistPoolRequest): + request = specialist_pool_service.DeleteSpecialistPoolRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -757,11 +772,9 @@ async def sample_delete_specialist_pool(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_specialist_pool, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_specialist_pool + ] # Certain fields should be provided within the metadata header; # add these here. @@ -878,8 +891,8 @@ async def sample_update_specialist_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([specialist_pool, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -887,7 +900,10 @@ async def sample_update_specialist_pool(): "the individual field arguments should be set." ) - request = specialist_pool_service.UpdateSpecialistPoolRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, specialist_pool_service.UpdateSpecialistPoolRequest): + request = specialist_pool_service.UpdateSpecialistPoolRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -898,11 +914,9 @@ async def sample_update_specialist_pool(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_specialist_pool, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_specialist_pool + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py index a78269b665..0925ab6e75 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -547,7 +548,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, SpecialistPoolServiceTransport]] = None, + transport: Optional[ + Union[ + str, + SpecialistPoolServiceTransport, + Callable[..., SpecialistPoolServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -559,9 +566,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, SpecialistPoolServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,SpecialistPoolServiceTransport,Callable[..., SpecialistPoolServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the SpecialistPoolServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -673,8 +682,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[SpecialistPoolServiceTransport], + Callable[..., SpecialistPoolServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., SpecialistPoolServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -775,8 +792,8 @@ def sample_create_specialist_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, specialist_pool]) if request is not None and has_flattened_params: raise ValueError( @@ -784,10 +801,8 @@ def sample_create_specialist_pool(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a specialist_pool_service.CreateSpecialistPoolRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, specialist_pool_service.CreateSpecialistPoolRequest): request = specialist_pool_service.CreateSpecialistPoolRequest(request) # If we have keyword arguments corresponding to fields on the @@ -903,8 +918,8 @@ def sample_get_specialist_pool(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -912,10 +927,8 @@ def sample_get_specialist_pool(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a specialist_pool_service.GetSpecialistPoolRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, specialist_pool_service.GetSpecialistPoolRequest): request = specialist_pool_service.GetSpecialistPoolRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1015,8 +1028,8 @@ def sample_list_specialist_pools(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1024,10 +1037,8 @@ def sample_list_specialist_pools(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a specialist_pool_service.ListSpecialistPoolsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, specialist_pool_service.ListSpecialistPoolsRequest): request = specialist_pool_service.ListSpecialistPoolsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1147,8 +1158,8 @@ def sample_delete_specialist_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1156,10 +1167,8 @@ def sample_delete_specialist_pool(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a specialist_pool_service.DeleteSpecialistPoolRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, specialist_pool_service.DeleteSpecialistPoolRequest): request = specialist_pool_service.DeleteSpecialistPoolRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1286,8 +1295,8 @@ def sample_update_specialist_pool(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([specialist_pool, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1295,10 +1304,8 @@ def sample_update_specialist_pool(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a specialist_pool_service.UpdateSpecialistPoolRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, specialist_pool_service.UpdateSpecialistPoolRequest): request = specialist_pool_service.UpdateSpecialistPoolRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py index de4f1b42d3..47322c234b 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py @@ -61,7 +61,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -81,14 +81,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. 
+ This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -98,11 +101,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -129,7 +132,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -170,7 +173,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py index 99a60e9a54..f75843c83f 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -76,7 +78,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -106,7 +107,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -126,15 +127,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -144,11 +148,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -175,7 +179,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -215,7 +219,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -406,6 +412,36 @@ def update_specialist_pool( ) return self._stubs["update_specialist_pool"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_specialist_pool: gapic_v1.method_async.wrap_method( + self.create_specialist_pool, + default_timeout=5.0, + client_info=client_info, + ), + self.get_specialist_pool: gapic_v1.method_async.wrap_method( + self.get_specialist_pool, + default_timeout=5.0, + client_info=client_info, + ), + self.list_specialist_pools: gapic_v1.method_async.wrap_method( + self.list_specialist_pools, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_specialist_pool: gapic_v1.method_async.wrap_method( + self.delete_specialist_pool, + default_timeout=5.0, + client_info=client_info, + ), + self.update_specialist_pool: gapic_v1.method_async.wrap_method( + self.update_specialist_pool, + default_timeout=5.0, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/rest.py index d681d09c98..c398948416 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/rest.py @@ -797,10 +797,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1171,10 +1167,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -1533,10 +1525,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -1911,10 +1899,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -2289,10 +2273,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -3733,10 +3713,6 @@ def __call__( 
"method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -4164,10 +4140,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -4586,10 +4558,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -5025,10 +4993,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -5464,10 +5428,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/async_client.py index 8fe5b4bd6d..9ccc6c0309 100644 --- 
a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -238,7 +239,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, TensorboardServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + TensorboardServiceTransport, + Callable[..., TensorboardServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -250,9 +257,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.TensorboardServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,TensorboardServiceTransport,Callable[..., TensorboardServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the TensorboardServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -381,8 +390,8 @@ async def sample_create_tensorboard(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, tensorboard]) if request is not None and has_flattened_params: raise ValueError( @@ -390,7 +399,10 @@ async def sample_create_tensorboard(): "the individual field arguments should be set." ) - request = tensorboard_service.CreateTensorboardRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.CreateTensorboardRequest): + request = tensorboard_service.CreateTensorboardRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -401,11 +413,9 @@ async def sample_create_tensorboard(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_tensorboard, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_tensorboard + ] # Certain fields should be provided within the metadata header; # add these here. @@ -502,8 +512,8 @@ async def sample_get_tensorboard(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -511,7 +521,10 @@ async def sample_get_tensorboard(): "the individual field arguments should be set." ) - request = tensorboard_service.GetTensorboardRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, tensorboard_service.GetTensorboardRequest): + request = tensorboard_service.GetTensorboardRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -520,11 +533,9 @@ async def sample_get_tensorboard(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_tensorboard, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_tensorboard + ] # Certain fields should be provided within the metadata header; # add these here. @@ -634,8 +645,8 @@ async def sample_update_tensorboard(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -643,7 +654,10 @@ async def sample_update_tensorboard(): "the individual field arguments should be set." ) - request = tensorboard_service.UpdateTensorboardRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.UpdateTensorboardRequest): + request = tensorboard_service.UpdateTensorboardRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -654,11 +668,9 @@ async def sample_update_tensorboard(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_tensorboard, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_tensorboard + ] # Certain fields should be provided within the metadata header; # add these here. @@ -758,8 +770,8 @@ async def sample_list_tensorboards(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -767,7 +779,10 @@ async def sample_list_tensorboards(): "the individual field arguments should be set." ) - request = tensorboard_service.ListTensorboardsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.ListTensorboardsRequest): + request = tensorboard_service.ListTensorboardsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -776,11 +791,9 @@ async def sample_list_tensorboards(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_tensorboards, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_tensorboards + ] # Certain fields should be provided within the metadata header; # add these here. @@ -889,8 +902,8 @@ async def sample_delete_tensorboard(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -898,7 +911,10 @@ async def sample_delete_tensorboard(): "the individual field arguments should be set." ) - request = tensorboard_service.DeleteTensorboardRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.DeleteTensorboardRequest): + request = tensorboard_service.DeleteTensorboardRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -907,11 +923,9 @@ async def sample_delete_tensorboard(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_tensorboard, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_tensorboard + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1005,8 +1019,8 @@ async def sample_read_tensorboard_usage(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard]) if request is not None and has_flattened_params: raise ValueError( @@ -1014,7 +1028,10 @@ async def sample_read_tensorboard_usage(): "the individual field arguments should be set." 
) - request = tensorboard_service.ReadTensorboardUsageRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.ReadTensorboardUsageRequest): + request = tensorboard_service.ReadTensorboardUsageRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1023,11 +1040,9 @@ async def sample_read_tensorboard_usage(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.read_tensorboard_usage, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.read_tensorboard_usage + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1115,8 +1130,8 @@ async def sample_read_tensorboard_size(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard]) if request is not None and has_flattened_params: raise ValueError( @@ -1124,7 +1139,10 @@ async def sample_read_tensorboard_size(): "the individual field arguments should be set." ) - request = tensorboard_service.ReadTensorboardSizeRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.ReadTensorboardSizeRequest): + request = tensorboard_service.ReadTensorboardSizeRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -1133,11 +1151,9 @@ async def sample_read_tensorboard_size(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.read_tensorboard_size, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.read_tensorboard_size + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1248,8 +1264,8 @@ async def sample_create_tensorboard_experiment(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, tensorboard_experiment, tensorboard_experiment_id] ) @@ -1259,7 +1275,12 @@ async def sample_create_tensorboard_experiment(): "the individual field arguments should be set." ) - request = tensorboard_service.CreateTensorboardExperimentRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.CreateTensorboardExperimentRequest + ): + request = tensorboard_service.CreateTensorboardExperimentRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1272,11 +1293,9 @@ async def sample_create_tensorboard_experiment(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_tensorboard_experiment, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_tensorboard_experiment + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1364,8 +1383,8 @@ async def sample_get_tensorboard_experiment(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1373,7 +1392,10 @@ async def sample_get_tensorboard_experiment(): "the individual field arguments should be set." ) - request = tensorboard_service.GetTensorboardExperimentRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.GetTensorboardExperimentRequest): + request = tensorboard_service.GetTensorboardExperimentRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1382,11 +1404,9 @@ async def sample_get_tensorboard_experiment(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_tensorboard_experiment, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_tensorboard_experiment + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -1489,8 +1509,8 @@ async def sample_update_tensorboard_experiment(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_experiment, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1498,7 +1518,12 @@ async def sample_update_tensorboard_experiment(): "the individual field arguments should be set." ) - request = tensorboard_service.UpdateTensorboardExperimentRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.UpdateTensorboardExperimentRequest + ): + request = tensorboard_service.UpdateTensorboardExperimentRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1509,11 +1534,9 @@ async def sample_update_tensorboard_experiment(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_tensorboard_experiment, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_tensorboard_experiment + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1605,8 +1628,8 @@ async def sample_list_tensorboard_experiments(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1614,7 +1637,12 @@ async def sample_list_tensorboard_experiments(): "the individual field arguments should be set." ) - request = tensorboard_service.ListTensorboardExperimentsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.ListTensorboardExperimentsRequest + ): + request = tensorboard_service.ListTensorboardExperimentsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1623,11 +1651,9 @@ async def sample_list_tensorboard_experiments(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_tensorboard_experiments, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_tensorboard_experiments + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1736,8 +1762,8 @@ async def sample_delete_tensorboard_experiment(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1745,7 +1771,12 @@ async def sample_delete_tensorboard_experiment(): "the individual field arguments should be set." 
) - request = tensorboard_service.DeleteTensorboardExperimentRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.DeleteTensorboardExperimentRequest + ): + request = tensorboard_service.DeleteTensorboardExperimentRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1754,11 +1785,9 @@ async def sample_delete_tensorboard_experiment(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_tensorboard_experiment, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_tensorboard_experiment + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1879,8 +1908,8 @@ async def sample_create_tensorboard_run(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard_run, tensorboard_run_id]) if request is not None and has_flattened_params: raise ValueError( @@ -1888,7 +1917,10 @@ async def sample_create_tensorboard_run(): "the individual field arguments should be set." ) - request = tensorboard_service.CreateTensorboardRunRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, tensorboard_service.CreateTensorboardRunRequest): + request = tensorboard_service.CreateTensorboardRunRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1901,11 +1933,9 @@ async def sample_create_tensorboard_run(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_tensorboard_run, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_tensorboard_run + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2012,8 +2042,8 @@ async def sample_batch_create_tensorboard_runs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: raise ValueError( @@ -2021,7 +2051,12 @@ async def sample_batch_create_tensorboard_runs(): "the individual field arguments should be set." ) - request = tensorboard_service.BatchCreateTensorboardRunsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.BatchCreateTensorboardRunsRequest + ): + request = tensorboard_service.BatchCreateTensorboardRunsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2032,11 +2067,9 @@ async def sample_batch_create_tensorboard_runs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.batch_create_tensorboard_runs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_create_tensorboard_runs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2124,8 +2157,8 @@ async def sample_get_tensorboard_run(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2133,7 +2166,10 @@ async def sample_get_tensorboard_run(): "the individual field arguments should be set." ) - request = tensorboard_service.GetTensorboardRunRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.GetTensorboardRunRequest): + request = tensorboard_service.GetTensorboardRunRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2142,11 +2178,9 @@ async def sample_get_tensorboard_run(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_tensorboard_run, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_tensorboard_run + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2250,8 +2284,8 @@ async def sample_update_tensorboard_run(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_run, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -2259,7 +2293,10 @@ async def sample_update_tensorboard_run(): "the individual field arguments should be set." ) - request = tensorboard_service.UpdateTensorboardRunRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.UpdateTensorboardRunRequest): + request = tensorboard_service.UpdateTensorboardRunRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2270,11 +2307,9 @@ async def sample_update_tensorboard_run(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_tensorboard_run, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_tensorboard_run + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2366,8 +2401,8 @@ async def sample_list_tensorboard_runs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2375,7 +2410,10 @@ async def sample_list_tensorboard_runs(): "the individual field arguments should be set." ) - request = tensorboard_service.ListTensorboardRunsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.ListTensorboardRunsRequest): + request = tensorboard_service.ListTensorboardRunsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2384,11 +2422,9 @@ async def sample_list_tensorboard_runs(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_tensorboard_runs, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_tensorboard_runs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2497,8 +2533,8 @@ async def sample_delete_tensorboard_run(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2506,7 +2542,10 @@ async def sample_delete_tensorboard_run(): "the individual field arguments should be set." ) - request = tensorboard_service.DeleteTensorboardRunRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, tensorboard_service.DeleteTensorboardRunRequest): + request = tensorboard_service.DeleteTensorboardRunRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2515,11 +2554,9 @@ async def sample_delete_tensorboard_run(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_tensorboard_run, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_tensorboard_run + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2636,8 +2673,8 @@ async def sample_batch_create_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: raise ValueError( @@ -2645,7 +2682,14 @@ async def sample_batch_create_tensorboard_time_series(): "the individual field arguments should be set." ) - request = tensorboard_service.BatchCreateTensorboardTimeSeriesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.BatchCreateTensorboardTimeSeriesRequest + ): + request = tensorboard_service.BatchCreateTensorboardTimeSeriesRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -2656,11 +2700,9 @@ async def sample_batch_create_tensorboard_time_series(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.batch_create_tensorboard_time_series, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_create_tensorboard_time_series + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2761,8 +2803,8 @@ async def sample_create_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard_time_series]) if request is not None and has_flattened_params: raise ValueError( @@ -2770,7 +2812,12 @@ async def sample_create_tensorboard_time_series(): "the individual field arguments should be set." ) - request = tensorboard_service.CreateTensorboardTimeSeriesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.CreateTensorboardTimeSeriesRequest + ): + request = tensorboard_service.CreateTensorboardTimeSeriesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2781,11 +2828,9 @@ async def sample_create_tensorboard_time_series(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_tensorboard_time_series, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_tensorboard_time_series + ] # Certain fields should be provided within the metadata header; # add these here. @@ -2871,8 +2916,8 @@ async def sample_get_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2880,7 +2925,10 @@ async def sample_get_tensorboard_time_series(): "the individual field arguments should be set." ) - request = tensorboard_service.GetTensorboardTimeSeriesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.GetTensorboardTimeSeriesRequest): + request = tensorboard_service.GetTensorboardTimeSeriesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -2889,11 +2937,9 @@ async def sample_get_tensorboard_time_series(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_tensorboard_time_series, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_tensorboard_time_series + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -2999,8 +3045,8 @@ async def sample_update_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_time_series, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -3008,7 +3054,12 @@ async def sample_update_tensorboard_time_series(): "the individual field arguments should be set." ) - request = tensorboard_service.UpdateTensorboardTimeSeriesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.UpdateTensorboardTimeSeriesRequest + ): + request = tensorboard_service.UpdateTensorboardTimeSeriesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3019,11 +3070,9 @@ async def sample_update_tensorboard_time_series(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_tensorboard_time_series, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_tensorboard_time_series + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3120,8 +3169,8 @@ async def sample_list_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3129,7 +3178,12 @@ async def sample_list_tensorboard_time_series(): "the individual field arguments should be set." ) - request = tensorboard_service.ListTensorboardTimeSeriesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.ListTensorboardTimeSeriesRequest + ): + request = tensorboard_service.ListTensorboardTimeSeriesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3138,11 +3192,9 @@ async def sample_list_tensorboard_time_series(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_tensorboard_time_series, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_tensorboard_time_series + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3251,8 +3303,8 @@ async def sample_delete_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3260,7 +3312,12 @@ async def sample_delete_tensorboard_time_series(): "the individual field arguments should be set." 
) - request = tensorboard_service.DeleteTensorboardTimeSeriesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.DeleteTensorboardTimeSeriesRequest + ): + request = tensorboard_service.DeleteTensorboardTimeSeriesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3269,11 +3326,9 @@ async def sample_delete_tensorboard_time_series(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_tensorboard_time_series, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_tensorboard_time_series + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3377,8 +3432,8 @@ async def sample_batch_read_tensorboard_time_series_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard]) if request is not None and has_flattened_params: raise ValueError( @@ -3386,7 +3441,14 @@ async def sample_batch_read_tensorboard_time_series_data(): "the individual field arguments should be set." ) - request = tensorboard_service.BatchReadTensorboardTimeSeriesDataRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance( + request, tensorboard_service.BatchReadTensorboardTimeSeriesDataRequest + ): + request = tensorboard_service.BatchReadTensorboardTimeSeriesDataRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3395,11 +3457,9 @@ async def sample_batch_read_tensorboard_time_series_data(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.batch_read_tensorboard_time_series_data, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_read_tensorboard_time_series_data + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3491,8 +3551,8 @@ async def sample_read_tensorboard_time_series_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_time_series]) if request is not None and has_flattened_params: raise ValueError( @@ -3500,7 +3560,12 @@ async def sample_read_tensorboard_time_series_data(): "the individual field arguments should be set." ) - request = tensorboard_service.ReadTensorboardTimeSeriesDataRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.ReadTensorboardTimeSeriesDataRequest + ): + request = tensorboard_service.ReadTensorboardTimeSeriesDataRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -3509,11 +3574,9 @@ async def sample_read_tensorboard_time_series_data(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.read_tensorboard_time_series_data, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.read_tensorboard_time_series_data + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3605,8 +3668,8 @@ async def sample_read_tensorboard_blob_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([time_series]) if request is not None and has_flattened_params: raise ValueError( @@ -3614,7 +3677,10 @@ async def sample_read_tensorboard_blob_data(): "the individual field arguments should be set." ) - request = tensorboard_service.ReadTensorboardBlobDataRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.ReadTensorboardBlobDataRequest): + request = tensorboard_service.ReadTensorboardBlobDataRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3623,11 +3689,9 @@ async def sample_read_tensorboard_blob_data(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.read_tensorboard_blob_data, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.read_tensorboard_blob_data + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3733,8 +3797,8 @@ async def sample_write_tensorboard_experiment_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_experiment, write_run_data_requests]) if request is not None and has_flattened_params: raise ValueError( @@ -3742,7 +3806,12 @@ async def sample_write_tensorboard_experiment_data(): "the individual field arguments should be set." ) - request = tensorboard_service.WriteTensorboardExperimentDataRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.WriteTensorboardExperimentDataRequest + ): + request = tensorboard_service.WriteTensorboardExperimentDataRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3753,11 +3822,9 @@ async def sample_write_tensorboard_experiment_data(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.write_tensorboard_experiment_data, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.write_tensorboard_experiment_data + ] # Certain fields should be provided within the metadata header; # add these here. @@ -3868,8 +3935,8 @@ async def sample_write_tensorboard_run_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_run, time_series_data]) if request is not None and has_flattened_params: raise ValueError( @@ -3877,7 +3944,10 @@ async def sample_write_tensorboard_run_data(): "the individual field arguments should be set." ) - request = tensorboard_service.WriteTensorboardRunDataRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, tensorboard_service.WriteTensorboardRunDataRequest): + request = tensorboard_service.WriteTensorboardRunDataRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -3888,11 +3958,9 @@ async def sample_write_tensorboard_run_data(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.write_tensorboard_run_data, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.write_tensorboard_run_data + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -3985,8 +4053,8 @@ async def sample_export_tensorboard_time_series_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_time_series]) if request is not None and has_flattened_params: raise ValueError( @@ -3994,7 +4062,14 @@ async def sample_export_tensorboard_time_series_data(): "the individual field arguments should be set." ) - request = tensorboard_service.ExportTensorboardTimeSeriesDataRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, tensorboard_service.ExportTensorboardTimeSeriesDataRequest + ): + request = tensorboard_service.ExportTensorboardTimeSeriesDataRequest( + request + ) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -4003,11 +4078,9 @@ async def sample_export_tensorboard_time_series_data(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.export_tensorboard_time_series_data, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.export_tensorboard_time_series_data + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/client.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/client.py index 4e9e10cacd..48553ea191 100644 --- a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -633,7 +634,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, TensorboardServiceTransport]] = None, + transport: Optional[ + Union[ + str, + TensorboardServiceTransport, + Callable[..., TensorboardServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -645,9 +652,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, TensorboardServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,TensorboardServiceTransport,Callable[..., TensorboardServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the TensorboardServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -759,8 +768,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[TensorboardServiceTransport], + Callable[..., TensorboardServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., TensorboardServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -854,8 +871,8 @@ def sample_create_tensorboard(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard]) if request is not None and has_flattened_params: raise ValueError( @@ -863,10 +880,8 @@ def sample_create_tensorboard(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.CreateTensorboardRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.CreateTensorboardRequest): request = tensorboard_service.CreateTensorboardRequest(request) # If we have keyword arguments corresponding to fields on the @@ -975,8 +990,8 @@ def sample_get_tensorboard(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -984,10 +999,8 @@ def sample_get_tensorboard(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.GetTensorboardRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.GetTensorboardRequest): request = tensorboard_service.GetTensorboardRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1107,8 +1120,8 @@ def sample_update_tensorboard(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1116,10 +1129,8 @@ def sample_update_tensorboard(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.UpdateTensorboardRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, tensorboard_service.UpdateTensorboardRequest): request = tensorboard_service.UpdateTensorboardRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1231,8 +1242,8 @@ def sample_list_tensorboards(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1240,10 +1251,8 @@ def sample_list_tensorboards(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.ListTensorboardsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.ListTensorboardsRequest): request = tensorboard_service.ListTensorboardsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1362,8 +1371,8 @@ def sample_delete_tensorboard(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1371,10 +1380,8 @@ def sample_delete_tensorboard(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.DeleteTensorboardRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.DeleteTensorboardRequest): request = tensorboard_service.DeleteTensorboardRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1478,8 +1485,8 @@ def sample_read_tensorboard_usage(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard]) if request is not None and has_flattened_params: raise ValueError( @@ -1487,10 +1494,8 @@ def sample_read_tensorboard_usage(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.ReadTensorboardUsageRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.ReadTensorboardUsageRequest): request = tensorboard_service.ReadTensorboardUsageRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1588,8 +1593,8 @@ def sample_read_tensorboard_size(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard]) if request is not None and has_flattened_params: raise ValueError( @@ -1597,10 +1602,8 @@ def sample_read_tensorboard_size(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.ReadTensorboardSizeRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.ReadTensorboardSizeRequest): request = tensorboard_service.ReadTensorboardSizeRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1721,8 +1724,8 @@ def sample_create_tensorboard_experiment(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any( [parent, tensorboard_experiment, tensorboard_experiment_id] ) @@ -1732,10 +1735,8 @@ def sample_create_tensorboard_experiment(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.CreateTensorboardExperimentRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, tensorboard_service.CreateTensorboardExperimentRequest ): @@ -1841,8 +1842,8 @@ def sample_get_tensorboard_experiment(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1850,10 +1851,8 @@ def sample_get_tensorboard_experiment(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.GetTensorboardExperimentRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.GetTensorboardExperimentRequest): request = tensorboard_service.GetTensorboardExperimentRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1968,8 +1967,8 @@ def sample_update_tensorboard_experiment(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_experiment, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -1977,10 +1976,8 @@ def sample_update_tensorboard_experiment(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.UpdateTensorboardExperimentRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.UpdateTensorboardExperimentRequest ): @@ -2088,8 +2085,8 @@ def sample_list_tensorboard_experiments(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2097,10 +2094,8 @@ def sample_list_tensorboard_experiments(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.ListTensorboardExperimentsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.ListTensorboardExperimentsRequest ): @@ -2223,8 +2218,8 @@ def sample_delete_tensorboard_experiment(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2232,10 +2227,8 @@ def sample_delete_tensorboard_experiment(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.DeleteTensorboardExperimentRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.DeleteTensorboardExperimentRequest ): @@ -2370,8 +2363,8 @@ def sample_create_tensorboard_run(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard_run, tensorboard_run_id]) if request is not None and has_flattened_params: raise ValueError( @@ -2379,10 +2372,8 @@ def sample_create_tensorboard_run(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.CreateTensorboardRunRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.CreateTensorboardRunRequest): request = tensorboard_service.CreateTensorboardRunRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2503,8 +2494,8 @@ def sample_batch_create_tensorboard_runs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: raise ValueError( @@ -2512,10 +2503,8 @@ def sample_batch_create_tensorboard_runs(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.BatchCreateTensorboardRunsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.BatchCreateTensorboardRunsRequest ): @@ -2619,8 +2608,8 @@ def sample_get_tensorboard_run(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -2628,10 +2617,8 @@ def sample_get_tensorboard_run(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.GetTensorboardRunRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, tensorboard_service.GetTensorboardRunRequest): request = tensorboard_service.GetTensorboardRunRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2745,8 +2732,8 @@ def sample_update_tensorboard_run(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_run, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -2754,10 +2741,8 @@ def sample_update_tensorboard_run(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.UpdateTensorboardRunRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.UpdateTensorboardRunRequest): request = tensorboard_service.UpdateTensorboardRunRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2861,8 +2846,8 @@ def sample_list_tensorboard_runs(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2870,10 +2855,8 @@ def sample_list_tensorboard_runs(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.ListTensorboardRunsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.ListTensorboardRunsRequest): request = tensorboard_service.ListTensorboardRunsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2992,8 +2975,8 @@ def sample_delete_tensorboard_run(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3001,10 +2984,8 @@ def sample_delete_tensorboard_run(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.DeleteTensorboardRunRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.DeleteTensorboardRunRequest): request = tensorboard_service.DeleteTensorboardRunRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3131,8 +3112,8 @@ def sample_batch_create_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, requests]) if request is not None and has_flattened_params: raise ValueError( @@ -3140,10 +3121,8 @@ def sample_batch_create_tensorboard_time_series(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.BatchCreateTensorboardTimeSeriesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.BatchCreateTensorboardTimeSeriesRequest ): @@ -3262,8 +3241,8 @@ def sample_create_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, tensorboard_time_series]) if request is not None and has_flattened_params: raise ValueError( @@ -3271,10 +3250,8 @@ def sample_create_tensorboard_time_series(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.CreateTensorboardTimeSeriesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, tensorboard_service.CreateTensorboardTimeSeriesRequest ): @@ -3376,8 +3353,8 @@ def sample_get_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3385,10 +3362,8 @@ def sample_get_tensorboard_time_series(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.GetTensorboardTimeSeriesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.GetTensorboardTimeSeriesRequest): request = tensorboard_service.GetTensorboardTimeSeriesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -3506,8 +3481,8 @@ def sample_update_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_time_series, update_mask]) if request is not None and has_flattened_params: raise ValueError( @@ -3515,10 +3490,8 @@ def sample_update_tensorboard_time_series(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.UpdateTensorboardTimeSeriesRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.UpdateTensorboardTimeSeriesRequest ): @@ -3631,8 +3604,8 @@ def sample_list_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -3640,10 +3613,8 @@ def sample_list_tensorboard_time_series(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.ListTensorboardTimeSeriesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.ListTensorboardTimeSeriesRequest ): @@ -3766,8 +3737,8 @@ def sample_delete_tensorboard_time_series(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -3775,10 +3746,8 @@ def sample_delete_tensorboard_time_series(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.DeleteTensorboardTimeSeriesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.DeleteTensorboardTimeSeriesRequest ): @@ -3896,8 +3865,8 @@ def sample_batch_read_tensorboard_time_series_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard]) if request is not None and has_flattened_params: raise ValueError( @@ -3905,10 +3874,8 @@ def sample_batch_read_tensorboard_time_series_data(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.BatchReadTensorboardTimeSeriesDataRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.BatchReadTensorboardTimeSeriesDataRequest ): @@ -4016,8 +3983,8 @@ def sample_read_tensorboard_time_series_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([tensorboard_time_series]) if request is not None and has_flattened_params: raise ValueError( @@ -4025,10 +3992,8 @@ def sample_read_tensorboard_time_series_data(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.ReadTensorboardTimeSeriesDataRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.ReadTensorboardTimeSeriesDataRequest ): @@ -4134,8 +4099,8 @@ def sample_read_tensorboard_blob_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([time_series]) if request is not None and has_flattened_params: raise ValueError( @@ -4143,10 +4108,8 @@ def sample_read_tensorboard_blob_data(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.ReadTensorboardBlobDataRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, tensorboard_service.ReadTensorboardBlobDataRequest): request = tensorboard_service.ReadTensorboardBlobDataRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4264,8 +4227,8 @@ def sample_write_tensorboard_experiment_data(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_experiment, write_run_data_requests]) if request is not None and has_flattened_params: raise ValueError( @@ -4273,10 +4236,8 @@ def sample_write_tensorboard_experiment_data(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.WriteTensorboardExperimentDataRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance( request, tensorboard_service.WriteTensorboardExperimentDataRequest ): @@ -4403,8 +4364,8 @@ def sample_write_tensorboard_run_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_run, time_series_data]) if request is not None and has_flattened_params: raise ValueError( @@ -4412,10 +4373,8 @@ def sample_write_tensorboard_run_data(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.WriteTensorboardRunDataRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, tensorboard_service.WriteTensorboardRunDataRequest): request = tensorboard_service.WriteTensorboardRunDataRequest(request) # If we have keyword arguments corresponding to fields on the @@ -4522,8 +4481,8 @@ def sample_export_tensorboard_time_series_data(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([tensorboard_time_series]) if request is not None and has_flattened_params: raise ValueError( @@ -4531,10 +4490,8 @@ def sample_export_tensorboard_time_series_data(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a tensorboard_service.ExportTensorboardTimeSeriesDataRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance( request, tensorboard_service.ExportTensorboardTimeSeriesDataRequest ): diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc.py index add1e37c8c..c96ec38a2a 100644 --- a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc.py @@ -66,7 +66,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -86,14 +86,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. 
If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -103,11 +106,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -134,7 +137,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -175,7 +178,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc_asyncio.py index 6a965a4f9b..6b1b064629 100644 --- a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -81,7 +83,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -111,7 +112,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -131,15 +132,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -149,11 +153,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -180,7 +184,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -220,7 +224,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -1183,6 +1189,161 @@ def export_tensorboard_time_series_data( ) return self._stubs["export_tensorboard_time_series_data"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_tensorboard: gapic_v1.method_async.wrap_method( + self.create_tensorboard, + default_timeout=None, + client_info=client_info, + ), + self.get_tensorboard: gapic_v1.method_async.wrap_method( + self.get_tensorboard, + default_timeout=None, + client_info=client_info, + ), + self.update_tensorboard: gapic_v1.method_async.wrap_method( + self.update_tensorboard, + default_timeout=None, + client_info=client_info, + ), + self.list_tensorboards: gapic_v1.method_async.wrap_method( + self.list_tensorboards, + default_timeout=None, + client_info=client_info, + ), + self.delete_tensorboard: gapic_v1.method_async.wrap_method( + self.delete_tensorboard, + default_timeout=None, + client_info=client_info, + ), + self.read_tensorboard_usage: gapic_v1.method_async.wrap_method( + self.read_tensorboard_usage, + default_timeout=None, + client_info=client_info, + ), + self.read_tensorboard_size: gapic_v1.method_async.wrap_method( + self.read_tensorboard_size, + default_timeout=None, + client_info=client_info, + ), + self.create_tensorboard_experiment: gapic_v1.method_async.wrap_method( + self.create_tensorboard_experiment, + default_timeout=None, + client_info=client_info, + ), + self.get_tensorboard_experiment: gapic_v1.method_async.wrap_method( + self.get_tensorboard_experiment, + default_timeout=None, + client_info=client_info, + ), + 
self.update_tensorboard_experiment: gapic_v1.method_async.wrap_method( + self.update_tensorboard_experiment, + default_timeout=None, + client_info=client_info, + ), + self.list_tensorboard_experiments: gapic_v1.method_async.wrap_method( + self.list_tensorboard_experiments, + default_timeout=None, + client_info=client_info, + ), + self.delete_tensorboard_experiment: gapic_v1.method_async.wrap_method( + self.delete_tensorboard_experiment, + default_timeout=None, + client_info=client_info, + ), + self.create_tensorboard_run: gapic_v1.method_async.wrap_method( + self.create_tensorboard_run, + default_timeout=None, + client_info=client_info, + ), + self.batch_create_tensorboard_runs: gapic_v1.method_async.wrap_method( + self.batch_create_tensorboard_runs, + default_timeout=None, + client_info=client_info, + ), + self.get_tensorboard_run: gapic_v1.method_async.wrap_method( + self.get_tensorboard_run, + default_timeout=None, + client_info=client_info, + ), + self.update_tensorboard_run: gapic_v1.method_async.wrap_method( + self.update_tensorboard_run, + default_timeout=None, + client_info=client_info, + ), + self.list_tensorboard_runs: gapic_v1.method_async.wrap_method( + self.list_tensorboard_runs, + default_timeout=None, + client_info=client_info, + ), + self.delete_tensorboard_run: gapic_v1.method_async.wrap_method( + self.delete_tensorboard_run, + default_timeout=None, + client_info=client_info, + ), + self.batch_create_tensorboard_time_series: gapic_v1.method_async.wrap_method( + self.batch_create_tensorboard_time_series, + default_timeout=None, + client_info=client_info, + ), + self.create_tensorboard_time_series: gapic_v1.method_async.wrap_method( + self.create_tensorboard_time_series, + default_timeout=None, + client_info=client_info, + ), + self.get_tensorboard_time_series: gapic_v1.method_async.wrap_method( + self.get_tensorboard_time_series, + default_timeout=None, + client_info=client_info, + ), + self.update_tensorboard_time_series: 
gapic_v1.method_async.wrap_method( + self.update_tensorboard_time_series, + default_timeout=None, + client_info=client_info, + ), + self.list_tensorboard_time_series: gapic_v1.method_async.wrap_method( + self.list_tensorboard_time_series, + default_timeout=None, + client_info=client_info, + ), + self.delete_tensorboard_time_series: gapic_v1.method_async.wrap_method( + self.delete_tensorboard_time_series, + default_timeout=None, + client_info=client_info, + ), + self.batch_read_tensorboard_time_series_data: gapic_v1.method_async.wrap_method( + self.batch_read_tensorboard_time_series_data, + default_timeout=None, + client_info=client_info, + ), + self.read_tensorboard_time_series_data: gapic_v1.method_async.wrap_method( + self.read_tensorboard_time_series_data, + default_timeout=None, + client_info=client_info, + ), + self.read_tensorboard_blob_data: gapic_v1.method_async.wrap_method( + self.read_tensorboard_blob_data, + default_timeout=None, + client_info=client_info, + ), + self.write_tensorboard_experiment_data: gapic_v1.method_async.wrap_method( + self.write_tensorboard_experiment_data, + default_timeout=None, + client_info=client_info, + ), + self.write_tensorboard_run_data: gapic_v1.method_async.wrap_method( + self.write_tensorboard_run_data, + default_timeout=None, + client_info=client_info, + ), + self.export_tensorboard_time_series_data: gapic_v1.method_async.wrap_method( + self.export_tensorboard_time_series_data, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/rest.py index 780a6087b8..200440a80e 100644 --- a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/rest.py @@ -1626,10 +1626,6 @@ def operations_client(self) -> 
operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -2000,10 +1996,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -2362,10 +2354,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2740,10 +2728,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -3118,10 +3102,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": 
"/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -7186,10 +7166,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -7617,10 +7593,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -8039,10 +8011,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -8478,10 +8446,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -8917,10 +8881,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/async_client.py 
b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/async_client.py index ce74571c15..064c8e08dd 100644 --- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -210,7 +211,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, VertexRagDataServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, + VertexRagDataServiceTransport, + Callable[..., VertexRagDataServiceTransport], + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -222,9 +229,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.VertexRagDataServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,VertexRagDataServiceTransport,Callable[..., VertexRagDataServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the VertexRagDataServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -351,8 +360,8 @@ async def sample_create_rag_corpus(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, rag_corpus]) if request is not None and has_flattened_params: raise ValueError( @@ -360,7 +369,10 @@ async def sample_create_rag_corpus(): "the individual field arguments should be set." ) - request = vertex_rag_data_service.CreateRagCorpusRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vertex_rag_data_service.CreateRagCorpusRequest): + request = vertex_rag_data_service.CreateRagCorpusRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -371,11 +383,9 @@ async def sample_create_rag_corpus(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_rag_corpus, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_rag_corpus + ] # Certain fields should be provided within the metadata header; # add these here. @@ -469,8 +479,8 @@ async def sample_get_rag_corpus(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -478,7 +488,10 @@ async def sample_get_rag_corpus(): "the individual field arguments should be set." 
) - request = vertex_rag_data_service.GetRagCorpusRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vertex_rag_data_service.GetRagCorpusRequest): + request = vertex_rag_data_service.GetRagCorpusRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -487,11 +500,9 @@ async def sample_get_rag_corpus(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_rag_corpus, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_rag_corpus + ] # Certain fields should be provided within the metadata header; # add these here. @@ -581,8 +592,8 @@ async def sample_list_rag_corpora(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -590,7 +601,10 @@ async def sample_list_rag_corpora(): "the individual field arguments should be set." ) - request = vertex_rag_data_service.ListRagCorporaRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vertex_rag_data_service.ListRagCorporaRequest): + request = vertex_rag_data_service.ListRagCorporaRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. 
@@ -599,11 +613,9 @@ async def sample_list_rag_corpora(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_rag_corpora, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_rag_corpora + ] # Certain fields should be provided within the metadata header; # add these here. @@ -712,8 +724,8 @@ async def sample_delete_rag_corpus(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -721,7 +733,10 @@ async def sample_delete_rag_corpus(): "the individual field arguments should be set." ) - request = vertex_rag_data_service.DeleteRagCorpusRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vertex_rag_data_service.DeleteRagCorpusRequest): + request = vertex_rag_data_service.DeleteRagCorpusRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -730,11 +745,9 @@ async def sample_delete_rag_corpus(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_rag_corpus, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_rag_corpus + ] # Certain fields should be provided within the metadata header; # add these here. 
@@ -848,8 +861,8 @@ async def sample_upload_rag_file(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, rag_file, upload_rag_file_config]) if request is not None and has_flattened_params: raise ValueError( @@ -857,7 +870,10 @@ async def sample_upload_rag_file(): "the individual field arguments should be set." ) - request = vertex_rag_data_service.UploadRagFileRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vertex_rag_data_service.UploadRagFileRequest): + request = vertex_rag_data_service.UploadRagFileRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -870,11 +886,9 @@ async def sample_upload_rag_file(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.upload_rag_file, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.upload_rag_file + ] # Certain fields should be provided within the metadata header; # add these here. @@ -980,8 +994,8 @@ async def sample_import_rag_files(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent, import_rag_files_config]) if request is not None and has_flattened_params: raise ValueError( @@ -989,7 +1003,10 @@ async def sample_import_rag_files(): "the individual field arguments should be set." ) - request = vertex_rag_data_service.ImportRagFilesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vertex_rag_data_service.ImportRagFilesRequest): + request = vertex_rag_data_service.ImportRagFilesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1000,11 +1017,9 @@ async def sample_import_rag_files(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.import_rag_files, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.import_rag_files + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1097,8 +1112,8 @@ async def sample_get_rag_file(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1106,7 +1121,10 @@ async def sample_get_rag_file(): "the individual field arguments should be set." ) - request = vertex_rag_data_service.GetRagFileRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, vertex_rag_data_service.GetRagFileRequest): + request = vertex_rag_data_service.GetRagFileRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1115,11 +1133,9 @@ async def sample_get_rag_file(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_rag_file, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_rag_file + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1209,8 +1225,8 @@ async def sample_list_rag_files(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1218,7 +1234,10 @@ async def sample_list_rag_files(): "the individual field arguments should be set." ) - request = vertex_rag_data_service.ListRagFilesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vertex_rag_data_service.ListRagFilesRequest): + request = vertex_rag_data_service.ListRagFilesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1227,11 +1246,9 @@ async def sample_list_rag_files(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_rag_files, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_rag_files + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1340,8 +1357,8 @@ async def sample_delete_rag_file(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1349,7 +1366,10 @@ async def sample_delete_rag_file(): "the individual field arguments should be set." ) - request = vertex_rag_data_service.DeleteRagFileRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vertex_rag_data_service.DeleteRagFileRequest): + request = vertex_rag_data_service.DeleteRagFileRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1358,11 +1378,9 @@ async def sample_delete_rag_file(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_rag_file, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_rag_file + ] # Certain fields should be provided within the metadata header; # add these here. 
diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/client.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/client.py index 6f13baed26..e118448ba2 100644 --- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -564,7 +565,13 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, VertexRagDataServiceTransport]] = None, + transport: Optional[ + Union[ + str, + VertexRagDataServiceTransport, + Callable[..., VertexRagDataServiceTransport], + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -576,9 +583,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, VertexRagDataServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,VertexRagDataServiceTransport,Callable[..., VertexRagDataServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the VertexRagDataServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -690,8 +699,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[VertexRagDataServiceTransport], + Callable[..., VertexRagDataServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., VertexRagDataServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -783,8 +800,8 @@ def sample_create_rag_corpus(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, rag_corpus]) if request is not None and has_flattened_params: raise ValueError( @@ -792,10 +809,8 @@ def sample_create_rag_corpus(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vertex_rag_data_service.CreateRagCorpusRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vertex_rag_data_service.CreateRagCorpusRequest): request = vertex_rag_data_service.CreateRagCorpusRequest(request) # If we have keyword arguments corresponding to fields on the @@ -901,8 +916,8 @@ def sample_get_rag_corpus(): """ # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -910,10 +925,8 @@ def sample_get_rag_corpus(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vertex_rag_data_service.GetRagCorpusRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vertex_rag_data_service.GetRagCorpusRequest): request = vertex_rag_data_service.GetRagCorpusRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1013,8 +1026,8 @@ def sample_list_rag_corpora(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1022,10 +1035,8 @@ def sample_list_rag_corpora(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vertex_rag_data_service.ListRagCorporaRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, vertex_rag_data_service.ListRagCorporaRequest): request = vertex_rag_data_service.ListRagCorporaRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1144,8 +1155,8 @@ def sample_delete_rag_corpus(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1153,10 +1164,8 @@ def sample_delete_rag_corpus(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vertex_rag_data_service.DeleteRagCorpusRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vertex_rag_data_service.DeleteRagCorpusRequest): request = vertex_rag_data_service.DeleteRagCorpusRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1280,8 +1289,8 @@ def sample_upload_rag_file(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, rag_file, upload_rag_file_config]) if request is not None and has_flattened_params: raise ValueError( @@ -1289,10 +1298,8 @@ def sample_upload_rag_file(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a vertex_rag_data_service.UploadRagFileRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vertex_rag_data_service.UploadRagFileRequest): request = vertex_rag_data_service.UploadRagFileRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1412,8 +1419,8 @@ def sample_import_rag_files(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, import_rag_files_config]) if request is not None and has_flattened_params: raise ValueError( @@ -1421,10 +1428,8 @@ def sample_import_rag_files(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vertex_rag_data_service.ImportRagFilesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vertex_rag_data_service.ImportRagFilesRequest): request = vertex_rag_data_service.ImportRagFilesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1529,8 +1534,8 @@ def sample_get_rag_file(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1538,10 +1543,8 @@ def sample_get_rag_file(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vertex_rag_data_service.GetRagFileRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vertex_rag_data_service.GetRagFileRequest): request = vertex_rag_data_service.GetRagFileRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1641,8 +1644,8 @@ def sample_list_rag_files(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1650,10 +1653,8 @@ def sample_list_rag_files(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vertex_rag_data_service.ListRagFilesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, vertex_rag_data_service.ListRagFilesRequest): request = vertex_rag_data_service.ListRagFilesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1772,8 +1773,8 @@ def sample_delete_rag_file(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1781,10 +1782,8 @@ def sample_delete_rag_file(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vertex_rag_data_service.DeleteRagFileRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, vertex_rag_data_service.DeleteRagFileRequest): request = vertex_rag_data_service.DeleteRagFileRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc.py index a8af2ddbc6..11b9e7599d 100644 --- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc.py @@ -56,7 +56,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -76,14 +76,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. 
+ channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -93,11 +96,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -124,7 +127,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -165,7 +168,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc_asyncio.py index fb5ef8e06d..157ff346a7 100644 --- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -71,7 +73,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -101,7 +102,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -121,15 +122,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -139,11 +143,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -170,7 +174,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -210,7 +214,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -516,6 +522,56 @@ def delete_rag_file( ) return self._stubs["delete_rag_file"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_rag_corpus: gapic_v1.method_async.wrap_method( + self.create_rag_corpus, + default_timeout=None, + client_info=client_info, + ), + self.get_rag_corpus: gapic_v1.method_async.wrap_method( + self.get_rag_corpus, + default_timeout=None, + client_info=client_info, + ), + self.list_rag_corpora: gapic_v1.method_async.wrap_method( + self.list_rag_corpora, + default_timeout=None, + client_info=client_info, + ), + self.delete_rag_corpus: gapic_v1.method_async.wrap_method( + self.delete_rag_corpus, + default_timeout=None, + client_info=client_info, + ), + self.upload_rag_file: gapic_v1.method_async.wrap_method( + self.upload_rag_file, + default_timeout=None, + client_info=client_info, + ), + self.import_rag_files: gapic_v1.method_async.wrap_method( + self.import_rag_files, + default_timeout=None, + client_info=client_info, + ), + self.get_rag_file: gapic_v1.method_async.wrap_method( + self.get_rag_file, + default_timeout=None, + client_info=client_info, + ), + self.list_rag_files: gapic_v1.method_async.wrap_method( + self.list_rag_files, + default_timeout=None, + client_info=client_info, + ), + self.delete_rag_file: gapic_v1.method_async.wrap_method( + self.delete_rag_file, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git 
a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/rest.py index 9ac548f5db..922f19288d 100644 --- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/transports/rest.py @@ -914,10 +914,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1288,10 +1284,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -1650,10 +1642,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2028,10 +2016,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": 
"/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -2406,10 +2390,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -4226,10 +4206,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -4657,10 +4633,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -5079,10 +5051,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -5518,10 +5486,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -5957,10 +5921,6 @@ def __call__( "method": "post", "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_service/async_client.py index 61bfe8e1d6..c8e1d49a86 100644 --- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -64,6 +65,8 @@ class VertexRagServiceAsyncClient: _DEFAULT_ENDPOINT_TEMPLATE = VertexRagServiceClient._DEFAULT_ENDPOINT_TEMPLATE _DEFAULT_UNIVERSE = VertexRagServiceClient._DEFAULT_UNIVERSE + rag_corpus_path = staticmethod(VertexRagServiceClient.rag_corpus_path) + parse_rag_corpus_path = staticmethod(VertexRagServiceClient.parse_rag_corpus_path) common_billing_account_path = staticmethod( VertexRagServiceClient.common_billing_account_path ) @@ -194,7 +197,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, VertexRagServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[ + str, VertexRagServiceTransport, Callable[..., VertexRagServiceTransport] + ] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -206,9 +213,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.VertexRagServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. 
+ transport (Optional[Union[str,VertexRagServiceTransport,Callable[..., VertexRagServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the VertexRagServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -285,14 +294,10 @@ async def sample_retrieve_contexts(): client = aiplatform_v1beta1.VertexRagServiceAsyncClient() # Initialize request argument(s) - vertex_rag_store = aiplatform_v1beta1.VertexRagStore() - vertex_rag_store.rag_corpora = ['rag_corpora_value1', 'rag_corpora_value2'] - query = aiplatform_v1beta1.RagQuery() query.text = "text_value" request = aiplatform_v1beta1.RetrieveContextsRequest( - vertex_rag_store=vertex_rag_store, parent="parent_value", query=query, ) @@ -334,8 +339,8 @@ async def sample_retrieve_contexts(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, query]) if request is not None and has_flattened_params: raise ValueError( @@ -343,7 +348,10 @@ async def sample_retrieve_contexts(): "the individual field arguments should be set." ) - request = vertex_rag_service.RetrieveContextsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, vertex_rag_service.RetrieveContextsRequest): + request = vertex_rag_service.RetrieveContextsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -354,11 +362,9 @@ async def sample_retrieve_contexts(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.retrieve_contexts, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.retrieve_contexts + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_service/client.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_service/client.py index 4280d980bf..b09149b6c7 100644 --- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -184,6 +185,28 @@ def transport(self) -> VertexRagServiceTransport: """ return self._transport + @staticmethod + def rag_corpus_path( + project: str, + location: str, + rag_corpus: str, + ) -> str: + """Returns a fully-qualified rag_corpus string.""" + return "projects/{project}/locations/{location}/ragCorpora/{rag_corpus}".format( + project=project, + location=location, + rag_corpus=rag_corpus, + ) + + @staticmethod + def parse_rag_corpus_path(path: str) -> Dict[str, str]: + """Parses a rag_corpus path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/ragCorpora/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def common_billing_account_path( billing_account: str, @@ -509,7 +532,11 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = 
None, - transport: Optional[Union[str, VertexRagServiceTransport]] = None, + transport: Optional[ + Union[ + str, VertexRagServiceTransport, Callable[..., VertexRagServiceTransport] + ] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -521,9 +548,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, VertexRagServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,VertexRagServiceTransport,Callable[..., VertexRagServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the VertexRagServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. 
@@ -635,8 +664,16 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[VertexRagServiceTransport], + Callable[..., VertexRagServiceTransport], + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., VertexRagServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -678,14 +715,10 @@ def sample_retrieve_contexts(): client = aiplatform_v1beta1.VertexRagServiceClient() # Initialize request argument(s) - vertex_rag_store = aiplatform_v1beta1.VertexRagStore() - vertex_rag_store.rag_corpora = ['rag_corpora_value1', 'rag_corpora_value2'] - query = aiplatform_v1beta1.RagQuery() query.text = "text_value" request = aiplatform_v1beta1.RetrieveContextsRequest( - vertex_rag_store=vertex_rag_store, parent="parent_value", query=query, ) @@ -727,8 +760,8 @@ def sample_retrieve_contexts(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, query]) if request is not None and has_flattened_params: raise ValueError( @@ -736,10 +769,8 @@ def sample_retrieve_contexts(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vertex_rag_service.RetrieveContextsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
+ # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vertex_rag_service.RetrieveContextsRequest): request = vertex_rag_service.RetrieveContextsRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_service/transports/grpc.py index b184da4455..a15cad9583 100644 --- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_service/transports/grpc.py @@ -54,7 +54,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -74,14 +74,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. 
+ channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -91,11 +94,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -121,7 +124,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -162,7 +165,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_service/transports/grpc_asyncio.py index 5aa450bc9c..9a05eb711a 100644 --- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -69,7 +71,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -99,7 +100,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -119,15 +120,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -137,11 +141,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -167,7 +171,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -207,7 +211,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -265,6 +271,16 @@ def retrieve_contexts( ) return self._stubs["retrieve_contexts"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.retrieve_contexts: gapic_v1.method_async.wrap_method( + self.retrieve_contexts, + default_timeout=None, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_service/transports/rest.py index dff2d3b314..ccbb9ddf16 100644 --- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_service/transports/rest.py @@ -1308,10 +1308,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1739,10 +1735,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -2161,10 +2153,6 @@ def __call__( "method": "get", "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2600,10 +2588,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -3039,10 +3023,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py index 97813d0ffd..6ab69e261c 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -211,7 +212,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, VizierServiceTransport] = "grpc_asyncio", + transport: Optional[ + Union[str, VizierServiceTransport, Callable[..., VizierServiceTransport]] + ] = "grpc_asyncio", client_options: Optional[ClientOptions] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -223,9 +226,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will 
attempt to ascertain the credentials from the environment. - transport (Union[str, ~.VizierServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. + transport (Optional[Union[str,VizierServiceTransport,Callable[..., VizierServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the VizierServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -350,8 +355,8 @@ async def sample_create_study(): A message representing a Study. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, study]) if request is not None and has_flattened_params: raise ValueError( @@ -359,7 +364,10 @@ async def sample_create_study(): "the individual field arguments should be set." ) - request = vizier_service.CreateStudyRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.CreateStudyRequest): + request = vizier_service.CreateStudyRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -370,11 +378,9 @@ async def sample_create_study(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_study, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_study + ] # Certain fields should be provided within the metadata header; # add these here. @@ -455,8 +461,8 @@ async def sample_get_study(): A message representing a Study. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -464,7 +470,10 @@ async def sample_get_study(): "the individual field arguments should be set." ) - request = vizier_service.GetStudyRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.GetStudyRequest): + request = vizier_service.GetStudyRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -473,11 +482,9 @@ async def sample_get_study(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_study, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_study + ] # Certain fields should be provided within the metadata header; # add these here. @@ -566,8 +573,8 @@ async def sample_list_studies(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -575,7 +582,10 @@ async def sample_list_studies(): "the individual field arguments should be set." ) - request = vizier_service.ListStudiesRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.ListStudiesRequest): + request = vizier_service.ListStudiesRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -584,11 +594,9 @@ async def sample_list_studies(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_studies, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_studies + ] # Certain fields should be provided within the metadata header; # add these here. @@ -672,8 +680,8 @@ async def sample_delete_study(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -681,7 +689,10 @@ async def sample_delete_study(): "the individual field arguments should be set." ) - request = vizier_service.DeleteStudyRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, vizier_service.DeleteStudyRequest): + request = vizier_service.DeleteStudyRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -690,11 +701,9 @@ async def sample_delete_study(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_study, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_study + ] # Certain fields should be provided within the metadata header; # add these here. @@ -775,8 +784,8 @@ async def sample_lookup_study(): A message representing a Study. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -784,7 +793,10 @@ async def sample_lookup_study(): "the individual field arguments should be set." ) - request = vizier_service.LookupStudyRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.LookupStudyRequest): + request = vizier_service.LookupStudyRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -793,11 +805,9 @@ async def sample_lookup_study(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.lookup_study, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.lookup_study + ] # Certain fields should be provided within the metadata header; # add these here. @@ -884,15 +894,16 @@ async def sample_suggest_trials(): """ # Create or coerce a protobuf request object. - request = vizier_service.SuggestTrialsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.SuggestTrialsRequest): + request = vizier_service.SuggestTrialsRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.suggest_trials, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.suggest_trials + ] # Certain fields should be provided within the metadata header; # add these here. @@ -993,8 +1004,8 @@ async def sample_create_trial(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, trial]) if request is not None and has_flattened_params: raise ValueError( @@ -1002,7 +1013,10 @@ async def sample_create_trial(): "the individual field arguments should be set." ) - request = vizier_service.CreateTrialRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, vizier_service.CreateTrialRequest): + request = vizier_service.CreateTrialRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1013,11 +1027,9 @@ async def sample_create_trial(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_trial, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_trial + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1103,8 +1115,8 @@ async def sample_get_trial(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1112,7 +1124,10 @@ async def sample_get_trial(): "the individual field arguments should be set." ) - request = vizier_service.GetTrialRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.GetTrialRequest): + request = vizier_service.GetTrialRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1121,11 +1136,9 @@ async def sample_get_trial(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_trial, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_trial + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1213,8 +1226,8 @@ async def sample_list_trials(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1222,7 +1235,10 @@ async def sample_list_trials(): "the individual field arguments should be set." ) - request = vizier_service.ListTrialsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.ListTrialsRequest): + request = vizier_service.ListTrialsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1231,11 +1247,9 @@ async def sample_list_trials(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_trials, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_trials + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1326,15 +1340,16 @@ async def sample_add_trial_measurement(): """ # Create or coerce a protobuf request object. 
- request = vizier_service.AddTrialMeasurementRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.AddTrialMeasurementRequest): + request = vizier_service.AddTrialMeasurementRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.add_trial_measurement, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.add_trial_measurement + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1414,15 +1429,16 @@ async def sample_complete_trial(): """ # Create or coerce a protobuf request object. - request = vizier_service.CompleteTrialRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.CompleteTrialRequest): + request = vizier_service.CompleteTrialRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.complete_trial, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.complete_trial + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1496,8 +1512,8 @@ async def sample_delete_trial(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1505,7 +1521,10 @@ async def sample_delete_trial(): "the individual field arguments should be set." ) - request = vizier_service.DeleteTrialRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.DeleteTrialRequest): + request = vizier_service.DeleteTrialRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1514,11 +1533,9 @@ async def sample_delete_trial(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_trial, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_trial + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1601,15 +1618,16 @@ async def sample_check_trial_early_stopping_state(): """ # Create or coerce a protobuf request object. - request = vizier_service.CheckTrialEarlyStoppingStateRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.CheckTrialEarlyStoppingStateRequest): + request = vizier_service.CheckTrialEarlyStoppingStateRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.check_trial_early_stopping_state, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.check_trial_early_stopping_state + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1697,15 +1715,16 @@ async def sample_stop_trial(): """ # Create or coerce a protobuf request object. - request = vizier_service.StopTrialRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vizier_service.StopTrialRequest): + request = vizier_service.StopTrialRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.stop_trial, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.stop_trial + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1791,8 +1810,8 @@ async def sample_list_optimal_trials(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1800,7 +1819,10 @@ async def sample_list_optimal_trials(): "the individual field arguments should be set." ) - request = vizier_service.ListOptimalTrialsRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
+ if not isinstance(request, vizier_service.ListOptimalTrialsRequest): + request = vizier_service.ListOptimalTrialsRequest(request) # If we have keyword arguments corresponding to fields on the # request, apply these. @@ -1809,11 +1831,9 @@ async def sample_list_optimal_trials(): # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_optimal_trials, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_optimal_trials + ] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py index 6ea0ac83f7..2e6b4f720a 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py @@ -18,6 +18,7 @@ import re from typing import ( Dict, + Callable, Mapping, MutableMapping, MutableSequence, @@ -586,7 +587,9 @@ def __init__( self, *, credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, VizierServiceTransport]] = None, + transport: Optional[ + Union[str, VizierServiceTransport, Callable[..., VizierServiceTransport]] + ] = None, client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -598,9 +601,11 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, VizierServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. 
+ transport (Optional[Union[str,VizierServiceTransport,Callable[..., VizierServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the VizierServiceTransport constructor. + If set to None, a transport is chosen automatically. NOTE: "rest" transport functionality is currently in a beta state (preview). We welcome your feedback via an issue in this library's source repository. @@ -712,8 +717,15 @@ def __init__( api_key_value ) - Transport = type(self).get_transport_class(cast(str, transport)) - self._transport = Transport( + transport_init: Union[ + Type[VizierServiceTransport], Callable[..., VizierServiceTransport] + ] = ( + type(self).get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., VizierServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( credentials=credentials, credentials_file=self._client_options.credentials_file, host=self._api_endpoint, @@ -803,8 +815,8 @@ def sample_create_study(): A message representing a Study. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, study]) if request is not None and has_flattened_params: raise ValueError( @@ -812,10 +824,8 @@ def sample_create_study(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.CreateStudyRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
+ # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.CreateStudyRequest): request = vizier_service.CreateStudyRequest(request) # If we have keyword arguments corresponding to fields on the @@ -908,8 +918,8 @@ def sample_get_study(): A message representing a Study. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -917,10 +927,8 @@ def sample_get_study(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.GetStudyRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.GetStudyRequest): request = vizier_service.GetStudyRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1019,8 +1027,8 @@ def sample_list_studies(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1028,10 +1036,8 @@ def sample_list_studies(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.ListStudiesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.ListStudiesRequest): request = vizier_service.ListStudiesRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1125,8 +1131,8 @@ def sample_delete_study(): sent along with the request as metadata. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1134,10 +1140,8 @@ def sample_delete_study(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.DeleteStudyRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.DeleteStudyRequest): request = vizier_service.DeleteStudyRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1228,8 +1232,8 @@ def sample_lookup_study(): A message representing a Study. """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. 
+ # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1237,10 +1241,8 @@ def sample_lookup_study(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.LookupStudyRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.LookupStudyRequest): request = vizier_service.LookupStudyRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1337,10 +1339,8 @@ def sample_suggest_trials(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.SuggestTrialsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.SuggestTrialsRequest): request = vizier_service.SuggestTrialsRequest(request) @@ -1447,8 +1447,8 @@ def sample_create_trial(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, trial]) if request is not None and has_flattened_params: raise ValueError( @@ -1456,10 +1456,8 @@ def sample_create_trial(): "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.CreateTrialRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.CreateTrialRequest): request = vizier_service.CreateTrialRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1557,8 +1555,8 @@ def sample_get_trial(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1566,10 +1564,8 @@ def sample_get_trial(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.GetTrialRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.GetTrialRequest): request = vizier_service.GetTrialRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1667,8 +1663,8 @@ def sample_list_trials(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -1676,10 +1672,8 @@ def sample_list_trials(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.ListTrialsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.ListTrialsRequest): request = vizier_service.ListTrialsRequest(request) # If we have keyword arguments corresponding to fields on the @@ -1780,10 +1774,8 @@ def sample_add_trial_measurement(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.AddTrialMeasurementRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.AddTrialMeasurementRequest): request = vizier_service.AddTrialMeasurementRequest(request) @@ -1869,10 +1861,8 @@ def sample_complete_trial(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.CompleteTrialRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.CompleteTrialRequest): request = vizier_service.CompleteTrialRequest(request) @@ -1952,8 +1942,8 @@ def sample_delete_trial(): sent along with the request as metadata. 
""" # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: raise ValueError( @@ -1961,10 +1951,8 @@ def sample_delete_trial(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.DeleteTrialRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.DeleteTrialRequest): request = vizier_service.DeleteTrialRequest(request) # If we have keyword arguments corresponding to fields on the @@ -2057,10 +2045,8 @@ def sample_check_trial_early_stopping_state(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.CheckTrialEarlyStoppingStateRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.CheckTrialEarlyStoppingStateRequest): request = vizier_service.CheckTrialEarlyStoppingStateRequest(request) @@ -2156,10 +2142,8 @@ def sample_stop_trial(): """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.StopTrialRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
+ # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, vizier_service.StopTrialRequest): request = vizier_service.StopTrialRequest(request) @@ -2251,8 +2235,8 @@ def sample_list_optimal_trials(): """ # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: raise ValueError( @@ -2260,10 +2244,8 @@ def sample_list_optimal_trials(): "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a vizier_service.ListOptimalTrialsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. 
if not isinstance(request, vizier_service.ListOptimalTrialsRequest): request = vizier_service.ListOptimalTrialsRequest(request) # If we have keyword arguments corresponding to fields on the diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py index d14bdb6d25..983c01ea84 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py @@ -62,7 +62,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -82,14 +82,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. + ignored if a ``channel`` instance is provided. 
+ channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -99,11 +102,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -130,7 +133,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -171,7 +174,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py index 148234d1a9..0ec33f049d 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py @@ -18,6 +18,8 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries from google.api_core import operations_v1 from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -77,7 +79,6 @@ def create_channel( the credentials from the environment. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
@@ -107,7 +108,7 @@ def __init__( credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, api_mtls_endpoint: Optional[str] = None, client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, @@ -127,15 +128,18 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. credentials_file (Optional[str]): A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if a ``channel`` instance is provided. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from @@ -145,11 +149,11 @@ def __init__( private key bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. + for the grpc channel. It is ignored if a ``channel`` instance is provided. client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): A callback to provide client certificate bytes and private key bytes, both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -176,7 +180,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if channel: + if isinstance(channel, aio.Channel): # Ignore credentials if a channel was passed. credentials = False # If a channel was explicitly provided, set it. 
@@ -216,7 +220,9 @@ def __init__( ) if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( self._host, # use the credentials which are saved credentials=self._credentials, @@ -682,6 +688,86 @@ def list_optimal_trials( ) return self._stubs["list_optimal_trials"] + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_study: gapic_v1.method_async.wrap_method( + self.create_study, + default_timeout=5.0, + client_info=client_info, + ), + self.get_study: gapic_v1.method_async.wrap_method( + self.get_study, + default_timeout=5.0, + client_info=client_info, + ), + self.list_studies: gapic_v1.method_async.wrap_method( + self.list_studies, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_study: gapic_v1.method_async.wrap_method( + self.delete_study, + default_timeout=5.0, + client_info=client_info, + ), + self.lookup_study: gapic_v1.method_async.wrap_method( + self.lookup_study, + default_timeout=5.0, + client_info=client_info, + ), + self.suggest_trials: gapic_v1.method_async.wrap_method( + self.suggest_trials, + default_timeout=5.0, + client_info=client_info, + ), + self.create_trial: gapic_v1.method_async.wrap_method( + self.create_trial, + default_timeout=5.0, + client_info=client_info, + ), + self.get_trial: gapic_v1.method_async.wrap_method( + self.get_trial, + default_timeout=5.0, + client_info=client_info, + ), + self.list_trials: gapic_v1.method_async.wrap_method( + self.list_trials, + default_timeout=5.0, + client_info=client_info, + ), + self.add_trial_measurement: gapic_v1.method_async.wrap_method( + self.add_trial_measurement, + default_timeout=5.0, + client_info=client_info, + ), + self.complete_trial: gapic_v1.method_async.wrap_method( 
+ self.complete_trial, + default_timeout=5.0, + client_info=client_info, + ), + self.delete_trial: gapic_v1.method_async.wrap_method( + self.delete_trial, + default_timeout=5.0, + client_info=client_info, + ), + self.check_trial_early_stopping_state: gapic_v1.method_async.wrap_method( + self.check_trial_early_stopping_state, + default_timeout=5.0, + client_info=client_info, + ), + self.stop_trial: gapic_v1.method_async.wrap_method( + self.stop_trial, + default_timeout=5.0, + client_info=client_info, + ), + self.list_optimal_trials: gapic_v1.method_async.wrap_method( + self.list_optimal_trials, + default_timeout=5.0, + client_info=client_info, + ), + } + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/rest.py index d3d280097d..3e3172d03f 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/rest.py @@ -1054,10 +1054,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -1428,10 +1424,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -1790,10 +1782,6 @@ def operations_client(self) -> 
operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -2168,10 +2156,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -2546,10 +2530,6 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", @@ -4937,10 +4917,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:cancel", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", @@ -5368,10 +5344,6 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", @@ -5790,10 +5762,6 @@ def __call__( "method": "get", "uri": 
"/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}", @@ -6229,10 +6197,6 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*}/operations", }, - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*}/operations", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*}/operations", @@ -6668,10 +6632,6 @@ def __call__( "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/operations/*}:wait", }, - { - "method": "post", - "uri": "/v1beta1/{name=projects/*/locations/*/extensions/*/deployments/*/operations/*}:wait", - }, { "method": "post", "uri": "/v1beta1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index d16c74107a..4f96dabbc7 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -44,6 +44,7 @@ Part, SafetyRating, SafetySetting, + SearchEntryPoint, Segment, VideoMetadata, HarmCategory, @@ -107,6 +108,7 @@ SearchDataItemsRequest, SearchDataItemsResponse, UpdateDatasetRequest, + UpdateDatasetVersionRequest, ) from .dataset_version import ( DatasetVersion, @@ -785,6 +787,9 @@ from .notebook_euc_config import ( NotebookEucConfig, ) +from .notebook_execution_job import ( + NotebookExecutionJob, +) from .notebook_idle_shutdown_config import ( NotebookIdleShutdownConfig, ) @@ -799,12 +804,17 @@ from .notebook_service import ( AssignNotebookRuntimeOperationMetadata, AssignNotebookRuntimeRequest, + CreateNotebookExecutionJobRequest, CreateNotebookRuntimeTemplateOperationMetadata, 
CreateNotebookRuntimeTemplateRequest, + DeleteNotebookExecutionJobRequest, DeleteNotebookRuntimeRequest, DeleteNotebookRuntimeTemplateRequest, + GetNotebookExecutionJobRequest, GetNotebookRuntimeRequest, GetNotebookRuntimeTemplateRequest, + ListNotebookExecutionJobsRequest, + ListNotebookExecutionJobsResponse, ListNotebookRuntimesRequest, ListNotebookRuntimesResponse, ListNotebookRuntimeTemplatesRequest, @@ -815,6 +825,7 @@ UpgradeNotebookRuntimeOperationMetadata, UpgradeNotebookRuntimeRequest, UpgradeNotebookRuntimeResponse, + NotebookExecutionJobView, ) from .openapi import ( Schema, @@ -878,7 +889,6 @@ PipelineState, ) from .prediction_service import ( - ChatCompletionsRequest, CountTokensRequest, CountTokensResponse, DirectPredictRequest, @@ -1141,6 +1151,7 @@ "Part", "SafetyRating", "SafetySetting", + "SearchEntryPoint", "Segment", "VideoMetadata", "HarmCategory", @@ -1192,6 +1203,7 @@ "SearchDataItemsRequest", "SearchDataItemsResponse", "UpdateDatasetRequest", + "UpdateDatasetVersionRequest", "DatasetVersion", "DeployedIndexRef", "DeployedModelRef", @@ -1735,6 +1747,7 @@ "NasTrialDetail", "NetworkSpec", "NotebookEucConfig", + "NotebookExecutionJob", "NotebookIdleShutdownConfig", "NotebookRuntime", "NotebookRuntimeTemplate", @@ -1742,12 +1755,17 @@ "NotebookRuntimeTemplateRef", "AssignNotebookRuntimeOperationMetadata", "AssignNotebookRuntimeRequest", + "CreateNotebookExecutionJobRequest", "CreateNotebookRuntimeTemplateOperationMetadata", "CreateNotebookRuntimeTemplateRequest", + "DeleteNotebookExecutionJobRequest", "DeleteNotebookRuntimeRequest", "DeleteNotebookRuntimeTemplateRequest", + "GetNotebookExecutionJobRequest", "GetNotebookRuntimeRequest", "GetNotebookRuntimeTemplateRequest", + "ListNotebookExecutionJobsRequest", + "ListNotebookExecutionJobsResponse", "ListNotebookRuntimesRequest", "ListNotebookRuntimesResponse", "ListNotebookRuntimeTemplatesRequest", @@ -1758,6 +1776,7 @@ "UpgradeNotebookRuntimeOperationMetadata", "UpgradeNotebookRuntimeRequest", 
"UpgradeNotebookRuntimeResponse", + "NotebookExecutionJobView", "Schema", "Type", "DeleteOperationMetadata", @@ -1803,7 +1822,6 @@ "ListTrainingPipelinesRequest", "ListTrainingPipelinesResponse", "PipelineState", - "ChatCompletionsRequest", "CountTokensRequest", "CountTokensResponse", "DirectPredictRequest", diff --git a/google/cloud/aiplatform_v1beta1/types/content.py b/google/cloud/aiplatform_v1beta1/types/content.py index a4b7b78bd8..fbd449c5c6 100644 --- a/google/cloud/aiplatform_v1beta1/types/content.py +++ b/google/cloud/aiplatform_v1beta1/types/content.py @@ -42,6 +42,7 @@ "Segment", "GroundingAttribution", "GroundingMetadata", + "SearchEntryPoint", }, ) @@ -824,10 +825,17 @@ class RetrievedContext(proto.Message): class GroundingMetadata(proto.Message): r"""Metadata returned to client when grounding is enabled. + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + Attributes: web_search_queries (MutableSequence[str]): Optional. Web search queries for the following-up web search. + search_entry_point (google.cloud.aiplatform_v1beta1.types.SearchEntryPoint): + Optional. Google search entry for the + following-up web searches. + + This field is a member of `oneof`_ ``_search_entry_point``. retrieval_queries (MutableSequence[str]): Optional. Queries executed by the retrieval tools. @@ -839,6 +847,12 @@ class GroundingMetadata(proto.Message): proto.STRING, number=1, ) + search_entry_point: "SearchEntryPoint" = proto.Field( + proto.MESSAGE, + number=4, + optional=True, + message="SearchEntryPoint", + ) retrieval_queries: MutableSequence[str] = proto.RepeatedField( proto.STRING, number=3, @@ -852,4 +866,26 @@ class GroundingMetadata(proto.Message): ) +class SearchEntryPoint(proto.Message): + r"""Google search entry point. + + Attributes: + rendered_content (str): + Optional. Web content snippet that can be + embedded in a web page or an app webview. + sdk_blob (bytes): + Optional. 
Base64 encoded JSON representing + array of tuple. + """ + + rendered_content: str = proto.Field( + proto.STRING, + number=1, + ) + sdk_blob: bytes = proto.Field( + proto.BYTES, + number=2, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/dataset_service.py b/google/cloud/aiplatform_v1beta1/types/dataset_service.py index c95cf0a988..0eb73ac2db 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset_service.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset_service.py @@ -35,6 +35,7 @@ "CreateDatasetOperationMetadata", "GetDatasetRequest", "UpdateDatasetRequest", + "UpdateDatasetVersionRequest", "ListDatasetsRequest", "ListDatasetsResponse", "DeleteDatasetRequest", @@ -160,6 +161,35 @@ class UpdateDatasetRequest(proto.Message): ) +class UpdateDatasetVersionRequest(proto.Message): + r"""Request message for + [DatasetService.UpdateDatasetVersion][google.cloud.aiplatform.v1beta1.DatasetService.UpdateDatasetVersion]. + + Attributes: + dataset_version (google.cloud.aiplatform_v1beta1.types.DatasetVersion): + Required. The DatasetVersion which replaces + the resource on the server. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The update mask applies to the resource. For the + ``FieldMask`` definition, see + [google.protobuf.FieldMask][google.protobuf.FieldMask]. + Updatable fields: + + - ``display_name`` + """ + + dataset_version: gca_dataset_version.DatasetVersion = proto.Field( + proto.MESSAGE, + number=1, + message=gca_dataset_version.DatasetVersion, + ) + update_mask: field_mask_pb2.FieldMask = proto.Field( + proto.MESSAGE, + number=2, + message=field_mask_pb2.FieldMask, + ) + + class ListDatasetsRequest(proto.Message): r"""Request message for [DatasetService.ListDatasets][google.cloud.aiplatform.v1beta1.DatasetService.ListDatasets]. 
diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint.py b/google/cloud/aiplatform_v1beta1/types/endpoint.py index 55c8e484bc..fa9200892a 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint.py @@ -23,6 +23,7 @@ from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import machine_resources +from google.cloud.aiplatform_v1beta1.types import service_networking from google.protobuf import timestamp_pb2 # type: ignore @@ -123,6 +124,13 @@ class Endpoint(proto.Message): or [enable_private_service_connect][google.cloud.aiplatform.v1beta1.Endpoint.enable_private_service_connect], can be set. + private_service_connect_config (google.cloud.aiplatform_v1beta1.types.PrivateServiceConnectConfig): + Optional. Configuration for private service connect. + + [network][google.cloud.aiplatform.v1beta1.Endpoint.network] + and + [private_service_connect_config][google.cloud.aiplatform.v1beta1.Endpoint.private_service_connect_config] + are mutually exclusive. model_deployment_monitoring_job (str): Output only. 
Resource name of the Model Monitoring job associated with this Endpoint if monitoring is enabled by @@ -188,6 +196,13 @@ class Endpoint(proto.Message): proto.BOOL, number=17, ) + private_service_connect_config: service_networking.PrivateServiceConnectConfig = ( + proto.Field( + proto.MESSAGE, + number=21, + message=service_networking.PrivateServiceConnectConfig, + ) + ) model_deployment_monitoring_job: str = proto.Field( proto.STRING, number=14, diff --git a/google/cloud/aiplatform_v1beta1/types/evaluation_service.py b/google/cloud/aiplatform_v1beta1/types/evaluation_service.py index fdedf72c71..64529fc687 100644 --- a/google/cloud/aiplatform_v1beta1/types/evaluation_service.py +++ b/google/cloud/aiplatform_v1beta1/types/evaluation_service.py @@ -880,7 +880,7 @@ class RougeSpec(proto.Message): Attributes: rouge_type (str): - Optional. Supported rouge types are rougen[1-9], rougeL and + Optional. Supported rouge types are rougen[1-9], rougeL, and rougeLsum. use_stemmer (bool): Optional. Whether to use stemmer to compute diff --git a/google/cloud/aiplatform_v1beta1/types/extension.py b/google/cloud/aiplatform_v1beta1/types/extension.py index bc021bb6fb..73b8e38e3d 100644 --- a/google/cloud/aiplatform_v1beta1/types/extension.py +++ b/google/cloud/aiplatform_v1beta1/types/extension.py @@ -577,20 +577,20 @@ class CodeInterpreterRuntimeConfig(proto.Message): Attributes: file_input_gcs_bucket (str): - Optional. The GCS bucket for file input of - this Extension. If specified, support input from - the GCS bucket. Vertex Extension Custom Code - Service Agent should be granted file reader to - this bucket. + Optional. The Cloud Storage bucket for file + input of this Extension. If specified, support + input from the Cloud Storage bucket. Vertex + Extension Custom Code Service Agent should be + granted file reader to this bucket. If not specified, the extension will only accept - file contents from request body and reject GCS - file inputs. 
+ file contents from request body and reject Cloud + Storage file inputs. file_output_gcs_bucket (str): - Optional. The GCS bucket for file output of - this Extension. If specified, write all output - files to the GCS bucket. Vertex Extension Custom - Code Service Agent should be granted file writer - to this bucket. + Optional. The Cloud Storage bucket for file + output of this Extension. If specified, write + all output files to the Cloud Storage bucket. + Vertex Extension Custom Code Service Agent + should be granted file writer to this bucket. If not specified, the file content will be output in response body. """ @@ -609,16 +609,25 @@ class VertexAISearchRuntimeConfig(proto.Message): Attributes: serving_config_name (str): - Required. Vertext AI Search serving config name. Format: + [Deprecated] Please use app_id instead. Vertex AI Search + serving config name. Format: ``projects/{project}/locations/{location}/collections/{collection}/engines/{engine}/servingConfigs/{serving_config}`` - or - ``projects/{project}/locations/{location}/collections/{collection}/dataStores/{data_store}/servingConfigs/{serving_config}`` + app_id (str): + Vertex AI Search App ID. This is used to construct the + search request. By setting this app_id, API will construct + the serving config which is required to call search API for + the user. The app_id and serving_config_name cannot both be + empty at the same time. 
""" serving_config_name: str = proto.Field( proto.STRING, number=1, ) + app_id: str = proto.Field( + proto.STRING, + number=2, + ) code_interpreter_runtime_config: CodeInterpreterRuntimeConfig = proto.Field( proto.MESSAGE, diff --git a/google/cloud/aiplatform_v1beta1/types/feature_group.py b/google/cloud/aiplatform_v1beta1/types/feature_group.py index c5c85ff0a4..cd7e7852de 100644 --- a/google/cloud/aiplatform_v1beta1/types/feature_group.py +++ b/google/cloud/aiplatform_v1beta1/types/feature_group.py @@ -40,8 +40,9 @@ class FeatureGroup(proto.Message): big_query (google.cloud.aiplatform_v1beta1.types.FeatureGroup.BigQuery): Indicates that features for this group come from BigQuery Table/View. By default treats the source as a sparse time - series source, which is required to have an entity_id and a - feature_timestamp column in the source. + series source. The BigQuery source table or view must have + at least one entity ID column and a column named + ``feature_timestamp``. This field is a member of `oneof`_ ``source``. name (str): diff --git a/google/cloud/aiplatform_v1beta1/types/feature_registry_service.py b/google/cloud/aiplatform_v1beta1/types/feature_registry_service.py index 6c2f889062..ff272bc5c9 100644 --- a/google/cloud/aiplatform_v1beta1/types/feature_registry_service.py +++ b/google/cloud/aiplatform_v1beta1/types/feature_registry_service.py @@ -49,7 +49,7 @@ class CreateFeatureGroupRequest(proto.Message): parent (str): Required. The resource name of the Location to create FeatureGroups. Format: - ``projects/{project}/locations/{location}'`` + ``projects/{project}/locations/{location}`` feature_group (google.cloud.aiplatform_v1beta1.types.FeatureGroup): Required. The FeatureGroup to create. 
feature_group_id (str): diff --git a/google/cloud/aiplatform_v1beta1/types/index_service.py b/google/cloud/aiplatform_v1beta1/types/index_service.py index fae6da8f99..4193529f5f 100644 --- a/google/cloud/aiplatform_v1beta1/types/index_service.py +++ b/google/cloud/aiplatform_v1beta1/types/index_service.py @@ -401,6 +401,8 @@ class RecordErrorType(proto.Enum): specified. INVALID_ENCODING (13): File is not in UTF_8 format. + INVALID_TOKEN_VALUE (15): + Token restrict value is invalid. """ ERROR_TYPE_UNSPECIFIED = 0 EMPTY_LINE = 1 @@ -416,6 +418,7 @@ class RecordErrorType(proto.Enum): MULTIPLE_VALUES = 11 INVALID_NUMERIC_VALUE = 12 INVALID_ENCODING = 13 + INVALID_TOKEN_VALUE = 15 error_type: "NearestNeighborSearchOperationMetadata.RecordError.RecordErrorType" = proto.Field( proto.ENUM, diff --git a/google/cloud/aiplatform_v1beta1/types/match_service.py b/google/cloud/aiplatform_v1beta1/types/match_service.py index 044b2055f2..456ebcb993 100644 --- a/google/cloud/aiplatform_v1beta1/types/match_service.py +++ b/google/cloud/aiplatform_v1beta1/types/match_service.py @@ -157,7 +157,7 @@ class Neighbor(proto.Message): are populated. distance (float): The distance between the neighbor and the - query vector. + dense embedding query. """ datapoint: index.IndexDatapoint = proto.Field( diff --git a/google/cloud/aiplatform_v1beta1/types/notebook_execution_job.py b/google/cloud/aiplatform_v1beta1/types/notebook_execution_job.py new file mode 100644 index 0000000000..6fa6caf27d --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/notebook_execution_job.py @@ -0,0 +1,223 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.cloud.aiplatform_v1beta1.types import job_state as gca_job_state +from google.protobuf import duration_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore +from google.rpc import status_pb2 # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1beta1", + manifest={ + "NotebookExecutionJob", + }, +) + + +class NotebookExecutionJob(proto.Message): + r"""NotebookExecutionJob represents an instance of a notebook + execution. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + dataform_repository_source (google.cloud.aiplatform_v1beta1.types.NotebookExecutionJob.DataformRepositorySource): + The Dataform Repository pointing to a single + file notebook repository. + + This field is a member of `oneof`_ ``notebook_source``. + gcs_notebook_source (google.cloud.aiplatform_v1beta1.types.NotebookExecutionJob.GcsNotebookSource): + The Cloud Storage url pointing to the ipynb file. Format: + ``gs://bucket/notebook_file.ipynb`` + + This field is a member of `oneof`_ ``notebook_source``. 
+ notebook_runtime_template_resource_name (str): + The NotebookRuntimeTemplate to source compute + configuration from. + + This field is a member of `oneof`_ ``environment_spec``. + gcs_output_uri (str): + The Cloud Storage location to upload the result to. Format: + ``gs://bucket-name`` + + This field is a member of `oneof`_ ``execution_sink``. + execution_user (str): + The user email to run the execution as. Only + supported by Colab runtimes. + + This field is a member of `oneof`_ ``execution_identity``. + service_account (str): + The service account to run the execution as. + + This field is a member of `oneof`_ ``execution_identity``. + name (str): + Output only. The resource name of this NotebookExecutionJob. + Format: + ``projects/{project_id}/locations/{location}/notebookExecutionJobs/{job_id}`` + display_name (str): + The display name of the NotebookExecutionJob. + The name can be up to 128 characters long and + can consist of any UTF-8 characters. + execution_timeout (google.protobuf.duration_pb2.Duration): + Max running time of the execution job in + seconds (default 86400s / 24 hrs). + schedule_resource_name (str): + Output only. The Schedule resource name if this job is + triggered by one. Format: + ``projects/{project_id}/locations/{location}/schedules/{schedule_id}`` + job_state (google.cloud.aiplatform_v1beta1.types.JobState): + Output only. The state of the + NotebookExecutionJob. + status (google.rpc.status_pb2.Status): + Output only. Populated when the + NotebookExecutionJob is completed. When there is + an error during notebook execution, the error + details are populated. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + NotebookExecutionJob was created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + NotebookExecutionJob was most recently updated. 
+ """ + + class DataformRepositorySource(proto.Message): + r"""The Dataform Repository containing the input notebook. + + Attributes: + dataform_repository_resource_name (str): + The resource name of the Dataform Repository. Format: + ``projects/{project_id}/locations/{location}/repositories/{repository_id}`` + commit_sha (str): + The commit SHA to read repository with. If + unset, the file will be read at HEAD. + """ + + dataform_repository_resource_name: str = proto.Field( + proto.STRING, + number=1, + ) + commit_sha: str = proto.Field( + proto.STRING, + number=2, + ) + + class GcsNotebookSource(proto.Message): + r"""The Cloud Storage uri for the input notebook. + + Attributes: + uri (str): + The Cloud Storage uri pointing to the ipynb file. Format: + ``gs://bucket/notebook_file.ipynb`` + generation (str): + The version of the Cloud Storage object to + read. If unset, the current version of the + object is read. See + https://cloud.google.com/storage/docs/metadata#generation-number. + """ + + uri: str = proto.Field( + proto.STRING, + number=1, + ) + generation: str = proto.Field( + proto.STRING, + number=2, + ) + + dataform_repository_source: DataformRepositorySource = proto.Field( + proto.MESSAGE, + number=3, + oneof="notebook_source", + message=DataformRepositorySource, + ) + gcs_notebook_source: GcsNotebookSource = proto.Field( + proto.MESSAGE, + number=4, + oneof="notebook_source", + message=GcsNotebookSource, + ) + notebook_runtime_template_resource_name: str = proto.Field( + proto.STRING, + number=14, + oneof="environment_spec", + ) + gcs_output_uri: str = proto.Field( + proto.STRING, + number=8, + oneof="execution_sink", + ) + execution_user: str = proto.Field( + proto.STRING, + number=9, + oneof="execution_identity", + ) + service_account: str = proto.Field( + proto.STRING, + number=18, + oneof="execution_identity", + ) + name: str = proto.Field( + proto.STRING, + number=1, + ) + display_name: str = proto.Field( + proto.STRING, + number=2, + ) + 
execution_timeout: duration_pb2.Duration = proto.Field( + proto.MESSAGE, + number=5, + message=duration_pb2.Duration, + ) + schedule_resource_name: str = proto.Field( + proto.STRING, + number=6, + ) + job_state: gca_job_state.JobState = proto.Field( + proto.ENUM, + number=10, + enum=gca_job_state.JobState, + ) + status: status_pb2.Status = proto.Field( + proto.MESSAGE, + number=11, + message=status_pb2.Status, + ) + create_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=12, + message=timestamp_pb2.Timestamp, + ) + update_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=13, + message=timestamp_pb2.Timestamp, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/notebook_runtime.py b/google/cloud/aiplatform_v1beta1/types/notebook_runtime.py index 044e0722ab..0470391d22 100644 --- a/google/cloud/aiplatform_v1beta1/types/notebook_runtime.py +++ b/google/cloud/aiplatform_v1beta1/types/notebook_runtime.py @@ -65,7 +65,7 @@ class NotebookRuntimeTemplate(proto.Message): Attributes: name (str): - Output only. The resource name of the + The resource name of the NotebookRuntimeTemplate. display_name (str): Required. 
The display name of the diff --git a/google/cloud/aiplatform_v1beta1/types/notebook_service.py b/google/cloud/aiplatform_v1beta1/types/notebook_service.py index 68620b3dd9..2a76f8dd7f 100644 --- a/google/cloud/aiplatform_v1beta1/types/notebook_service.py +++ b/google/cloud/aiplatform_v1beta1/types/notebook_service.py @@ -19,6 +19,9 @@ import proto # type: ignore +from google.cloud.aiplatform_v1beta1.types import ( + notebook_execution_job as gca_notebook_execution_job, +) from google.cloud.aiplatform_v1beta1.types import ( notebook_runtime as gca_notebook_runtime, ) @@ -29,6 +32,7 @@ __protobuf__ = proto.module( package="google.cloud.aiplatform.v1beta1", manifest={ + "NotebookExecutionJobView", "CreateNotebookRuntimeTemplateRequest", "CreateNotebookRuntimeTemplateOperationMetadata", "GetNotebookRuntimeTemplateRequest", @@ -47,10 +51,33 @@ "StartNotebookRuntimeRequest", "StartNotebookRuntimeOperationMetadata", "StartNotebookRuntimeResponse", + "CreateNotebookExecutionJobRequest", + "GetNotebookExecutionJobRequest", + "ListNotebookExecutionJobsRequest", + "ListNotebookExecutionJobsResponse", + "DeleteNotebookExecutionJobRequest", }, ) +class NotebookExecutionJobView(proto.Enum): + r"""Views for Get/List NotebookExecutionJob + + Values: + NOTEBOOK_EXECUTION_JOB_VIEW_UNSPECIFIED (0): + When unspecified, the API defaults to the + BASIC view. + NOTEBOOK_EXECUTION_JOB_VIEW_BASIC (1): + Includes all fields except for direct + notebook inputs. + NOTEBOOK_EXECUTION_JOB_VIEW_FULL (2): + Includes all fields. + """ + NOTEBOOK_EXECUTION_JOB_VIEW_UNSPECIFIED = 0 + NOTEBOOK_EXECUTION_JOB_VIEW_BASIC = 1 + NOTEBOOK_EXECUTION_JOB_VIEW_FULL = 2 + + class CreateNotebookRuntimeTemplateRequest(proto.Message): r"""Request message for [NotebookService.CreateNotebookRuntimeTemplate][google.cloud.aiplatform.v1beta1.NotebookService.CreateNotebookRuntimeTemplate]. 
@@ -593,4 +620,179 @@ class StartNotebookRuntimeResponse(proto.Message): """ +class CreateNotebookExecutionJobRequest(proto.Message): + r"""Request message for [NotebookService.CreateNotebookExecutionJob] + + Attributes: + parent (str): + Required. The resource name of the Location to create the + NotebookExecutionJob. Format: + ``projects/{project}/locations/{location}`` + notebook_execution_job (google.cloud.aiplatform_v1beta1.types.NotebookExecutionJob): + Required. The NotebookExecutionJob to create. + notebook_execution_job_id (str): + Optional. User specified ID for the + NotebookExecutionJob. + """ + + parent: str = proto.Field( + proto.STRING, + number=1, + ) + notebook_execution_job: gca_notebook_execution_job.NotebookExecutionJob = ( + proto.Field( + proto.MESSAGE, + number=2, + message=gca_notebook_execution_job.NotebookExecutionJob, + ) + ) + notebook_execution_job_id: str = proto.Field( + proto.STRING, + number=3, + ) + + +class GetNotebookExecutionJobRequest(proto.Message): + r"""Request message for [NotebookService.GetNotebookExecutionJob] + + Attributes: + name (str): + Required. The name of the + NotebookExecutionJob resource. + view (google.cloud.aiplatform_v1beta1.types.NotebookExecutionJobView): + Optional. The NotebookExecutionJob view. + Defaults to BASIC. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + view: "NotebookExecutionJobView" = proto.Field( + proto.ENUM, + number=6, + enum="NotebookExecutionJobView", + ) + + +class ListNotebookExecutionJobsRequest(proto.Message): + r"""Request message for [NotebookService.ListNotebookExecutionJobs] + + Attributes: + parent (str): + Required. The resource name of the Location from which to + list the NotebookExecutionJobs. Format: + ``projects/{project}/locations/{location}`` + filter (str): + Optional. An expression for filtering the results of the + request. For field names both snake_case and camelCase are + supported. + + - ``notebookExecutionJob`` supports = and !=. 
+ ``notebookExecutionJob`` represents the + NotebookExecutionJob ID. + - ``displayName`` supports = and != and regex. + - ``schedule`` supports = and != and regex. + + Some examples: + + - ``notebookExecutionJob="123"`` + - ``notebookExecutionJob="my-execution-job"`` + - ``displayName="myDisplayName"`` and + ``displayName=~"myDisplayNameRegex"`` + page_size (int): + Optional. The standard list page size. + page_token (str): + Optional. The standard list page token. Typically obtained + via [ListNotebookExecutionJobs.next_page_token][] of the + previous + [NotebookService.ListNotebookExecutionJobs][google.cloud.aiplatform.v1beta1.NotebookService.ListNotebookExecutionJobs] + call. + order_by (str): + Optional. A comma-separated list of fields to order by, + sorted in ascending order. Use "desc" after a field name for + descending. Supported fields: + + - ``display_name`` + - ``create_time`` + - ``update_time`` + + Example: ``display_name, create_time desc``. + view (google.cloud.aiplatform_v1beta1.types.NotebookExecutionJobView): + Optional. The NotebookExecutionJob view. + Defaults to BASIC. + """ + + parent: str = proto.Field( + proto.STRING, + number=1, + ) + filter: str = proto.Field( + proto.STRING, + number=2, + ) + page_size: int = proto.Field( + proto.INT32, + number=3, + ) + page_token: str = proto.Field( + proto.STRING, + number=4, + ) + order_by: str = proto.Field( + proto.STRING, + number=5, + ) + view: "NotebookExecutionJobView" = proto.Field( + proto.ENUM, + number=6, + enum="NotebookExecutionJobView", + ) + + +class ListNotebookExecutionJobsResponse(proto.Message): + r"""Response message for [NotebookService.ListNotebookExecutionJobs] + + Attributes: + notebook_execution_jobs (MutableSequence[google.cloud.aiplatform_v1beta1.types.NotebookExecutionJob]): + List of NotebookExecutionJobs in the + requested page. + next_page_token (str): + A token to retrieve next page of results. Pass to + [ListNotebookExecutionJobs.page_token][] to obtain that + page. 
+ """ + + @property + def raw_page(self): + return self + + notebook_execution_jobs: MutableSequence[ + gca_notebook_execution_job.NotebookExecutionJob + ] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=gca_notebook_execution_job.NotebookExecutionJob, + ) + next_page_token: str = proto.Field( + proto.STRING, + number=2, + ) + + +class DeleteNotebookExecutionJobRequest(proto.Message): + r"""Request message for [NotebookService.DeleteNotebookExecutionJob] + + Attributes: + name (str): + Required. The name of the + NotebookExecutionJob resource to be deleted. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/prediction_service.py b/google/cloud/aiplatform_v1beta1/types/prediction_service.py index 464fd65386..eb2eba2054 100644 --- a/google/cloud/aiplatform_v1beta1/types/prediction_service.py +++ b/google/cloud/aiplatform_v1beta1/types/prediction_service.py @@ -51,7 +51,6 @@ "CountTokensResponse", "GenerateContentRequest", "GenerateContentResponse", - "ChatCompletionsRequest", }, ) @@ -952,28 +951,4 @@ class UsageMetadata(proto.Message): ) -class ChatCompletionsRequest(proto.Message): - r"""Request message for [PredictionService.ChatCompletions] - - Attributes: - endpoint (str): - Required. The name of the Endpoint requested to serve the - prediction. Format: - ``projects/{project}/locations/{location}/endpoints/openapi`` - http_body (google.api.httpbody_pb2.HttpBody): - Optional. The prediction input. Supports HTTP - headers and arbitrary data payload. 
- """ - - endpoint: str = proto.Field( - proto.STRING, - number=1, - ) - http_body: httpbody_pb2.HttpBody = proto.Field( - proto.MESSAGE, - number=2, - message=httpbody_pb2.HttpBody, - ) - - __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/publisher_model.py b/google/cloud/aiplatform_v1beta1/types/publisher_model.py index 12cb188b36..3ec6a22685 100644 --- a/google/cloud/aiplatform_v1beta1/types/publisher_model.py +++ b/google/cloud/aiplatform_v1beta1/types/publisher_model.py @@ -451,6 +451,11 @@ class Deploy(proto.Message): Optional. The path to the directory containing the Model artifact and any of its supporting files. + deploy_task_name (str): + Optional. The name of the deploy task (e.g., + "text to image generation"). + + This field is a member of `oneof`_ ``_deploy_task_name``. title (str): Required. The title of the regional resource reference. @@ -494,6 +499,11 @@ class Deploy(proto.Message): proto.STRING, number=4, ) + deploy_task_name: str = proto.Field( + proto.STRING, + number=10, + optional=True, + ) title: str = proto.Field( proto.STRING, number=8, diff --git a/google/cloud/aiplatform_v1beta1/types/schedule.py b/google/cloud/aiplatform_v1beta1/types/schedule.py index 93a4c6a497..02dab2e352 100644 --- a/google/cloud/aiplatform_v1beta1/types/schedule.py +++ b/google/cloud/aiplatform_v1beta1/types/schedule.py @@ -20,6 +20,7 @@ import proto # type: ignore from google.cloud.aiplatform_v1beta1.types import model_monitoring_service +from google.cloud.aiplatform_v1beta1.types import notebook_service from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.protobuf import timestamp_pb2 # type: ignore @@ -66,6 +67,10 @@ class Schedule(proto.Message): Request for [ModelMonitoringService.CreateModelMonitoringJob][google.cloud.aiplatform.v1beta1.ModelMonitoringService.CreateModelMonitoringJob]. + This field is a member of `oneof`_ ``request``. 
+ create_notebook_execution_job_request (google.cloud.aiplatform_v1beta1.types.CreateNotebookExecutionJobRequest): + Request for [NotebookService.CreateNotebookExecutionJob][]. + This field is a member of `oneof`_ ``request``. name (str): Immutable. The resource name of the Schedule. @@ -202,6 +207,12 @@ class RunResponse(proto.Message): oneof="request", message=model_monitoring_service.CreateModelMonitoringJobRequest, ) + create_notebook_execution_job_request: notebook_service.CreateNotebookExecutionJobRequest = proto.Field( + proto.MESSAGE, + number=20, + oneof="request", + message=notebook_service.CreateNotebookExecutionJobRequest, + ) name: str = proto.Field( proto.STRING, number=1, diff --git a/google/cloud/aiplatform_v1beta1/types/tool.py b/google/cloud/aiplatform_v1beta1/types/tool.py index 26b5498661..36786f3658 100644 --- a/google/cloud/aiplatform_v1beta1/types/tool.py +++ b/google/cloud/aiplatform_v1beta1/types/tool.py @@ -356,27 +356,65 @@ class VertexRagStore(proto.Message): Attributes: rag_corpora (MutableSequence[str]): - Required. Vertex RAG Store corpus resource name: - ``projects/{project}/locations/{location}/ragCorpora/{ragCorpus}`` - Currently only one corpus is allowed. In the future we may - open up multiple corpora support. However, they should be - from the same project and location. + Optional. Deprecated. Please use rag_resources instead. + rag_resources (MutableSequence[google.cloud.aiplatform_v1beta1.types.VertexRagStore.RagResource]): + Optional. The representation of the rag + source. It can be used to specify corpus only or + ragfiles. Currently only support one corpus or + multiple files from one corpus. In the future we + may open up multiple corpora support. similarity_top_k (int): Optional. Number of top k results to return from the selected corpora. This field is a member of `oneof`_ ``_similarity_top_k``. + vector_distance_threshold (float): + Optional. Only return results with vector + distance smaller than the threshold. 
+ + This field is a member of `oneof`_ ``_vector_distance_threshold``. """ + class RagResource(proto.Message): + r"""The definition of the Rag resource. + + Attributes: + rag_corpus (str): + Optional. RagCorpora resource name. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + rag_file_ids (MutableSequence[str]): + Optional. rag_file_id. The files should be in the same + rag_corpus set in rag_corpus field. + """ + + rag_corpus: str = proto.Field( + proto.STRING, + number=1, + ) + rag_file_ids: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=2, + ) + rag_corpora: MutableSequence[str] = proto.RepeatedField( proto.STRING, number=1, ) + rag_resources: MutableSequence[RagResource] = proto.RepeatedField( + proto.MESSAGE, + number=4, + message=RagResource, + ) similarity_top_k: int = proto.Field( proto.INT32, number=2, optional=True, ) + vector_distance_threshold: float = proto.Field( + proto.DOUBLE, + number=3, + optional=True, + ) class VertexAISearch(proto.Message): diff --git a/google/cloud/aiplatform_v1beta1/types/vertex_rag_data.py b/google/cloud/aiplatform_v1beta1/types/vertex_rag_data.py index d0947308c2..39a98a7209 100644 --- a/google/cloud/aiplatform_v1beta1/types/vertex_rag_data.py +++ b/google/cloud/aiplatform_v1beta1/types/vertex_rag_data.py @@ -96,7 +96,7 @@ class RagFile(proto.Message): gcs_source (google.cloud.aiplatform_v1beta1.types.GcsSource): Output only. Google Cloud Storage location of the RagFile. It does not support wildcards in - the GCS uri for now. + the Cloud Storage uri for now. This field is a member of `oneof`_ ``rag_file_source``. 
google_drive_source (google.cloud.aiplatform_v1beta1.types.GoogleDriveSource): diff --git a/google/cloud/aiplatform_v1beta1/types/vertex_rag_data_service.py b/google/cloud/aiplatform_v1beta1/types/vertex_rag_data_service.py index 2857fab341..5b96a8bca1 100644 --- a/google/cloud/aiplatform_v1beta1/types/vertex_rag_data_service.py +++ b/google/cloud/aiplatform_v1beta1/types/vertex_rag_data_service.py @@ -279,12 +279,26 @@ class ImportRagFilesResponse(proto.Message): imported_rag_files_count (int): The number of RagFiles that had been imported into the RagCorpus. + failed_rag_files_count (int): + The number of RagFiles that had failed while + importing into the RagCorpus. + skipped_rag_files_count (int): + The number of RagFiles that was skipped while + importing into the RagCorpus. """ imported_rag_files_count: int = proto.Field( proto.INT64, number=1, ) + failed_rag_files_count: int = proto.Field( + proto.INT64, + number=2, + ) + skipped_rag_files_count: int = proto.Field( + proto.INT64, + number=3, + ) class GetRagFileRequest(proto.Message): @@ -408,6 +422,9 @@ class ImportRagFilesOperationMetadata(proto.Message): rag_corpus_id (int): The resource ID of RagCorpus that this operation is executed on. + import_rag_files_config (google.cloud.aiplatform_v1beta1.types.ImportRagFilesConfig): + Output only. The config that was passed in + the ImportRagFilesRequest. 
""" generic_metadata: operation.GenericOperationMetadata = proto.Field( @@ -419,6 +436,11 @@ class ImportRagFilesOperationMetadata(proto.Message): proto.INT64, number=2, ) + import_rag_files_config: vertex_rag_data.ImportRagFilesConfig = proto.Field( + proto.MESSAGE, + number=3, + message=vertex_rag_data.ImportRagFilesConfig, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/vertex_rag_service.py b/google/cloud/aiplatform_v1beta1/types/vertex_rag_service.py index 6ea53b31dd..e33f30e95c 100644 --- a/google/cloud/aiplatform_v1beta1/types/vertex_rag_service.py +++ b/google/cloud/aiplatform_v1beta1/types/vertex_rag_service.py @@ -81,19 +81,62 @@ class RetrieveContextsRequest(proto.Message): class VertexRagStore(proto.Message): r"""The data source for Vertex RagStore. + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + Attributes: rag_corpora (MutableSequence[str]): - Required. RagCorpora resource name. Format: - ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` - Currently only one corpus is allowed. In the future we may - open up multiple corpora support. However, they should be - from the same project and location. + Optional. Deprecated. Please use rag_resources to specify + the data source. + rag_resources (MutableSequence[google.cloud.aiplatform_v1beta1.types.RetrieveContextsRequest.VertexRagStore.RagResource]): + Optional. The representation of the rag + source. It can be used to specify corpus only or + ragfiles. Currently only support one corpus or + multiple files from one corpus. In the future we + may open up multiple corpora support. + vector_distance_threshold (float): + Optional. Only return contexts with vector + distance smaller than the threshold. + + This field is a member of `oneof`_ ``_vector_distance_threshold``. """ + class RagResource(proto.Message): + r"""The definition of the Rag resource. 
+ + Attributes: + rag_corpus (str): + Optional. RagCorpora resource name. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + rag_file_ids (MutableSequence[str]): + Optional. rag_file_id. The files should be in the same + rag_corpus set in rag_corpus field. + """ + + rag_corpus: str = proto.Field( + proto.STRING, + number=1, + ) + rag_file_ids: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=2, + ) + rag_corpora: MutableSequence[str] = proto.RepeatedField( proto.STRING, number=1, ) + rag_resources: MutableSequence[ + "RetrieveContextsRequest.VertexRagStore.RagResource" + ] = proto.RepeatedField( + proto.MESSAGE, + number=3, + message="RetrieveContextsRequest.VertexRagStore.RagResource", + ) + vector_distance_threshold: float = proto.Field( + proto.DOUBLE, + number=2, + optional=True, + ) vertex_rag_store: VertexRagStore = proto.Field( proto.MESSAGE, diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_dataset_service_update_dataset_version_async.py b/samples/generated_samples/aiplatform_v1beta1_generated_dataset_service_update_dataset_version_async.py new file mode 100644 index 0000000000..3411205e7b --- /dev/null +++ b/samples/generated_samples/aiplatform_v1beta1_generated_dataset_service_update_dataset_version_async.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! 
+# +# Snippet for UpdateDatasetVersion +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1beta1_generated_DatasetService_UpdateDatasetVersion_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1beta1 + + +async def sample_update_dataset_version(): + # Create a client + client = aiplatform_v1beta1.DatasetServiceAsyncClient() + + # Initialize request argument(s) + dataset_version = aiplatform_v1beta1.DatasetVersion() + dataset_version.metadata.null_value = "NULL_VALUE" + + request = aiplatform_v1beta1.UpdateDatasetVersionRequest( + dataset_version=dataset_version, + ) + + # Make the request + response = await client.update_dataset_version(request=request) + + # Handle the response + print(response) + +# [END aiplatform_v1beta1_generated_DatasetService_UpdateDatasetVersion_async] diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_dataset_service_update_dataset_version_sync.py b/samples/generated_samples/aiplatform_v1beta1_generated_dataset_service_update_dataset_version_sync.py new file mode 100644 index 0000000000..5442d9a3a0 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1beta1_generated_dataset_service_update_dataset_version_sync.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in 
compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdateDatasetVersion +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1beta1_generated_DatasetService_UpdateDatasetVersion_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1beta1 + + +def sample_update_dataset_version(): + # Create a client + client = aiplatform_v1beta1.DatasetServiceClient() + + # Initialize request argument(s) + dataset_version = aiplatform_v1beta1.DatasetVersion() + dataset_version.metadata.null_value = "NULL_VALUE" + + request = aiplatform_v1beta1.UpdateDatasetVersionRequest( + dataset_version=dataset_version, + ) + + # Make the request + response = client.update_dataset_version(request=request) + + # Handle the response + print(response) + +# [END aiplatform_v1beta1_generated_DatasetService_UpdateDatasetVersion_sync] diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_delete_notebook_execution_job_async.py b/samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_delete_notebook_execution_job_async.py new file mode 100644 index 0000000000..9a5eb42113 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_delete_notebook_execution_job_async.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for DeleteNotebookExecutionJob +# NOTE: This snippet has been automatically generated for illustrative purposes only. 
+# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1beta1_generated_NotebookService_DeleteNotebookExecutionJob_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1beta1 + + +async def sample_delete_notebook_execution_job(): + # Create a client + client = aiplatform_v1beta1.NotebookServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.DeleteNotebookExecutionJobRequest( + name="name_value", + ) + + # Make the request + operation = client.delete_notebook_execution_job(request=request) + + print("Waiting for operation to complete...") + + response = (await operation).result() + + # Handle the response + print(response) + +# [END aiplatform_v1beta1_generated_NotebookService_DeleteNotebookExecutionJob_async] diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_delete_notebook_execution_job_sync.py b/samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_delete_notebook_execution_job_sync.py new file mode 100644 index 0000000000..2f55b15f4d --- /dev/null +++ b/samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_delete_notebook_execution_job_sync.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for DeleteNotebookExecutionJob +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1beta1_generated_NotebookService_DeleteNotebookExecutionJob_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1beta1 + + +def sample_delete_notebook_execution_job(): + # Create a client + client = aiplatform_v1beta1.NotebookServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.DeleteNotebookExecutionJobRequest( + name="name_value", + ) + + # Make the request + operation = client.delete_notebook_execution_job(request=request) + + print("Waiting for operation to complete...") + + response = operation.result() + + # Handle the response + print(response) + +# [END aiplatform_v1beta1_generated_NotebookService_DeleteNotebookExecutionJob_sync] diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_get_notebook_execution_job_async.py b/samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_get_notebook_execution_job_async.py new file mode 100644 index 0000000000..e4f1f030c2 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_get_notebook_execution_job_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetNotebookExecutionJob +# NOTE: This snippet has been automatically generated for illustrative purposes only. 
+# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1beta1_generated_NotebookService_GetNotebookExecutionJob_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1beta1 + + +async def sample_get_notebook_execution_job(): + # Create a client + client = aiplatform_v1beta1.NotebookServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.GetNotebookExecutionJobRequest( + name="name_value", + ) + + # Make the request + response = await client.get_notebook_execution_job(request=request) + + # Handle the response + print(response) + +# [END aiplatform_v1beta1_generated_NotebookService_GetNotebookExecutionJob_async] diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_get_notebook_execution_job_sync.py b/samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_get_notebook_execution_job_sync.py new file mode 100644 index 0000000000..66d2da36fb --- /dev/null +++ b/samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_get_notebook_execution_job_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetNotebookExecutionJob +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1beta1_generated_NotebookService_GetNotebookExecutionJob_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1beta1 + + +def sample_get_notebook_execution_job(): + # Create a client + client = aiplatform_v1beta1.NotebookServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.GetNotebookExecutionJobRequest( + name="name_value", + ) + + # Make the request + response = client.get_notebook_execution_job(request=request) + + # Handle the response + print(response) + +# [END aiplatform_v1beta1_generated_NotebookService_GetNotebookExecutionJob_sync] diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_list_notebook_execution_jobs_async.py b/samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_list_notebook_execution_jobs_async.py new file mode 100644 index 0000000000..24851899e0 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_list_notebook_execution_jobs_async.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListNotebookExecutionJobs +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. 
+ +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1beta1_generated_NotebookService_ListNotebookExecutionJobs_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1beta1 + + +async def sample_list_notebook_execution_jobs(): + # Create a client + client = aiplatform_v1beta1.NotebookServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.ListNotebookExecutionJobsRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_notebook_execution_jobs(request=request) + + # Handle the response + async for response in page_result: + print(response) + +# [END aiplatform_v1beta1_generated_NotebookService_ListNotebookExecutionJobs_async] diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_list_notebook_execution_jobs_sync.py b/samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_list_notebook_execution_jobs_sync.py new file mode 100644 index 0000000000..3cb04660fe --- /dev/null +++ b/samples/generated_samples/aiplatform_v1beta1_generated_notebook_service_list_notebook_execution_jobs_sync.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListNotebookExecutionJobs +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1beta1_generated_NotebookService_ListNotebookExecutionJobs_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1beta1 + + +def sample_list_notebook_execution_jobs(): + # Create a client + client = aiplatform_v1beta1.NotebookServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.ListNotebookExecutionJobsRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_notebook_execution_jobs(request=request) + + # Handle the response + for response in page_result: + print(response) + +# [END aiplatform_v1beta1_generated_NotebookService_ListNotebookExecutionJobs_sync] diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_vertex_rag_service_retrieve_contexts_async.py b/samples/generated_samples/aiplatform_v1beta1_generated_vertex_rag_service_retrieve_contexts_async.py index 7505fd1e36..c89939d1eb 100644 --- a/samples/generated_samples/aiplatform_v1beta1_generated_vertex_rag_service_retrieve_contexts_async.py +++ b/samples/generated_samples/aiplatform_v1beta1_generated_vertex_rag_service_retrieve_contexts_async.py @@ -39,14 +39,10 @@ async def sample_retrieve_contexts(): client = aiplatform_v1beta1.VertexRagServiceAsyncClient() # Initialize request argument(s) - vertex_rag_store = aiplatform_v1beta1.VertexRagStore() - vertex_rag_store.rag_corpora = ['rag_corpora_value1', 'rag_corpora_value2'] - query = aiplatform_v1beta1.RagQuery() query.text = "text_value" request = aiplatform_v1beta1.RetrieveContextsRequest( - vertex_rag_store=vertex_rag_store, parent="parent_value", query=query, ) diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_vertex_rag_service_retrieve_contexts_sync.py b/samples/generated_samples/aiplatform_v1beta1_generated_vertex_rag_service_retrieve_contexts_sync.py index 987da38bf0..4946f58159 100644 --- 
a/samples/generated_samples/aiplatform_v1beta1_generated_vertex_rag_service_retrieve_contexts_sync.py +++ b/samples/generated_samples/aiplatform_v1beta1_generated_vertex_rag_service_retrieve_contexts_sync.py @@ -39,14 +39,10 @@ def sample_retrieve_contexts(): client = aiplatform_v1beta1.VertexRagServiceClient() # Initialize request argument(s) - vertex_rag_store = aiplatform_v1beta1.VertexRagStore() - vertex_rag_store.rag_corpora = ['rag_corpora_value1', 'rag_corpora_value2'] - query = aiplatform_v1beta1.RagQuery() query.text = "text_value" request = aiplatform_v1beta1.RetrieveContextsRequest( - vertex_rag_store=vertex_rag_store, parent="parent_value", query=query, ) diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json index 39acefe786..a266a24b4b 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.50.0" + "version": "0.1.0" }, "snippets": [ { diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json index 79deba42e4..0ce1d4d7c4 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.50.0" + "version": "0.1.0" }, "snippets": [ { @@ -2772,6 +2772,175 @@ ], "title": "aiplatform_v1beta1_generated_dataset_service_search_data_items_sync.py" }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.cloud.aiplatform_v1beta1.DatasetServiceAsyncClient", + "shortName": 
"DatasetServiceAsyncClient" + }, + "fullName": "google.cloud.aiplatform_v1beta1.DatasetServiceAsyncClient.update_dataset_version", + "method": { + "fullName": "google.cloud.aiplatform.v1beta1.DatasetService.UpdateDatasetVersion", + "service": { + "fullName": "google.cloud.aiplatform.v1beta1.DatasetService", + "shortName": "DatasetService" + }, + "shortName": "UpdateDatasetVersion" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1beta1.types.UpdateDatasetVersionRequest" + }, + { + "name": "dataset_version", + "type": "google.cloud.aiplatform_v1beta1.types.DatasetVersion" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.cloud.aiplatform_v1beta1.types.DatasetVersion", + "shortName": "update_dataset_version" + }, + "description": "Sample for UpdateDatasetVersion", + "file": "aiplatform_v1beta1_generated_dataset_service_update_dataset_version_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1beta1_generated_DatasetService_UpdateDatasetVersion_async", + "segments": [ + { + "end": 54, + "start": 27, + "type": "FULL" + }, + { + "end": 54, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 48, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 51, + "start": 49, + "type": "REQUEST_EXECUTION" + }, + { + "end": 55, + "start": 52, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1beta1_generated_dataset_service_update_dataset_version_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.cloud.aiplatform_v1beta1.DatasetServiceClient", + "shortName": "DatasetServiceClient" + }, + "fullName": 
"google.cloud.aiplatform_v1beta1.DatasetServiceClient.update_dataset_version", + "method": { + "fullName": "google.cloud.aiplatform.v1beta1.DatasetService.UpdateDatasetVersion", + "service": { + "fullName": "google.cloud.aiplatform.v1beta1.DatasetService", + "shortName": "DatasetService" + }, + "shortName": "UpdateDatasetVersion" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1beta1.types.UpdateDatasetVersionRequest" + }, + { + "name": "dataset_version", + "type": "google.cloud.aiplatform_v1beta1.types.DatasetVersion" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.cloud.aiplatform_v1beta1.types.DatasetVersion", + "shortName": "update_dataset_version" + }, + "description": "Sample for UpdateDatasetVersion", + "file": "aiplatform_v1beta1_generated_dataset_service_update_dataset_version_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1beta1_generated_DatasetService_UpdateDatasetVersion_sync", + "segments": [ + { + "end": 54, + "start": 27, + "type": "FULL" + }, + { + "end": 54, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 48, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 51, + "start": 49, + "type": "REQUEST_EXECUTION" + }, + { + "end": 55, + "start": 52, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1beta1_generated_dataset_service_update_dataset_version_sync.py" + }, { "canonical": true, "clientMethod": { @@ -34354,6 +34523,167 @@ ], "title": "aiplatform_v1beta1_generated_notebook_service_create_notebook_runtime_template_sync.py" }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": 
"google.cloud.aiplatform_v1beta1.NotebookServiceAsyncClient", + "shortName": "NotebookServiceAsyncClient" + }, + "fullName": "google.cloud.aiplatform_v1beta1.NotebookServiceAsyncClient.delete_notebook_execution_job", + "method": { + "fullName": "google.cloud.aiplatform.v1beta1.NotebookService.DeleteNotebookExecutionJob", + "service": { + "fullName": "google.cloud.aiplatform.v1beta1.NotebookService", + "shortName": "NotebookService" + }, + "shortName": "DeleteNotebookExecutionJob" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1beta1.types.DeleteNotebookExecutionJobRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.api_core.operation_async.AsyncOperation", + "shortName": "delete_notebook_execution_job" + }, + "description": "Sample for DeleteNotebookExecutionJob", + "file": "aiplatform_v1beta1_generated_notebook_service_delete_notebook_execution_job_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1beta1_generated_NotebookService_DeleteNotebookExecutionJob_async", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1beta1_generated_notebook_service_delete_notebook_execution_job_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.cloud.aiplatform_v1beta1.NotebookServiceClient", + "shortName": "NotebookServiceClient" + }, + "fullName": 
"google.cloud.aiplatform_v1beta1.NotebookServiceClient.delete_notebook_execution_job", + "method": { + "fullName": "google.cloud.aiplatform.v1beta1.NotebookService.DeleteNotebookExecutionJob", + "service": { + "fullName": "google.cloud.aiplatform.v1beta1.NotebookService", + "shortName": "NotebookService" + }, + "shortName": "DeleteNotebookExecutionJob" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1beta1.types.DeleteNotebookExecutionJobRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.api_core.operation.Operation", + "shortName": "delete_notebook_execution_job" + }, + "description": "Sample for DeleteNotebookExecutionJob", + "file": "aiplatform_v1beta1_generated_notebook_service_delete_notebook_execution_job_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1beta1_generated_NotebookService_DeleteNotebookExecutionJob_sync", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1beta1_generated_notebook_service_delete_notebook_execution_job_sync.py" + }, { "canonical": true, "clientMethod": { @@ -34676,6 +35006,167 @@ ], "title": "aiplatform_v1beta1_generated_notebook_service_delete_notebook_runtime_sync.py" }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.cloud.aiplatform_v1beta1.NotebookServiceAsyncClient", + "shortName": "NotebookServiceAsyncClient" + 
}, + "fullName": "google.cloud.aiplatform_v1beta1.NotebookServiceAsyncClient.get_notebook_execution_job", + "method": { + "fullName": "google.cloud.aiplatform.v1beta1.NotebookService.GetNotebookExecutionJob", + "service": { + "fullName": "google.cloud.aiplatform.v1beta1.NotebookService", + "shortName": "NotebookService" + }, + "shortName": "GetNotebookExecutionJob" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1beta1.types.GetNotebookExecutionJobRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.cloud.aiplatform_v1beta1.types.NotebookExecutionJob", + "shortName": "get_notebook_execution_job" + }, + "description": "Sample for GetNotebookExecutionJob", + "file": "aiplatform_v1beta1_generated_notebook_service_get_notebook_execution_job_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1beta1_generated_NotebookService_GetNotebookExecutionJob_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1beta1_generated_notebook_service_get_notebook_execution_job_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.cloud.aiplatform_v1beta1.NotebookServiceClient", + "shortName": "NotebookServiceClient" + }, + "fullName": "google.cloud.aiplatform_v1beta1.NotebookServiceClient.get_notebook_execution_job", + "method": { + "fullName": 
"google.cloud.aiplatform.v1beta1.NotebookService.GetNotebookExecutionJob", + "service": { + "fullName": "google.cloud.aiplatform.v1beta1.NotebookService", + "shortName": "NotebookService" + }, + "shortName": "GetNotebookExecutionJob" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1beta1.types.GetNotebookExecutionJobRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.cloud.aiplatform_v1beta1.types.NotebookExecutionJob", + "shortName": "get_notebook_execution_job" + }, + "description": "Sample for GetNotebookExecutionJob", + "file": "aiplatform_v1beta1_generated_notebook_service_get_notebook_execution_job_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1beta1_generated_NotebookService_GetNotebookExecutionJob_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1beta1_generated_notebook_service_get_notebook_execution_job_sync.py" + }, { "canonical": true, "clientMethod": { @@ -34916,7 +35407,168 @@ "type": "RESPONSE_HANDLING" } ], - "title": "aiplatform_v1beta1_generated_notebook_service_get_notebook_runtime_async.py" + "title": "aiplatform_v1beta1_generated_notebook_service_get_notebook_runtime_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.cloud.aiplatform_v1beta1.NotebookServiceClient", + "shortName": "NotebookServiceClient" + }, + "fullName": 
"google.cloud.aiplatform_v1beta1.NotebookServiceClient.get_notebook_runtime", + "method": { + "fullName": "google.cloud.aiplatform.v1beta1.NotebookService.GetNotebookRuntime", + "service": { + "fullName": "google.cloud.aiplatform.v1beta1.NotebookService", + "shortName": "NotebookService" + }, + "shortName": "GetNotebookRuntime" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1beta1.types.GetNotebookRuntimeRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.cloud.aiplatform_v1beta1.types.NotebookRuntime", + "shortName": "get_notebook_runtime" + }, + "description": "Sample for GetNotebookRuntime", + "file": "aiplatform_v1beta1_generated_notebook_service_get_notebook_runtime_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1beta1_generated_NotebookService_GetNotebookRuntime_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1beta1_generated_notebook_service_get_notebook_runtime_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.cloud.aiplatform_v1beta1.NotebookServiceAsyncClient", + "shortName": "NotebookServiceAsyncClient" + }, + "fullName": "google.cloud.aiplatform_v1beta1.NotebookServiceAsyncClient.list_notebook_execution_jobs", + "method": { + "fullName": "google.cloud.aiplatform.v1beta1.NotebookService.ListNotebookExecutionJobs", + 
"service": { + "fullName": "google.cloud.aiplatform.v1beta1.NotebookService", + "shortName": "NotebookService" + }, + "shortName": "ListNotebookExecutionJobs" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1beta1.types.ListNotebookExecutionJobsRequest" + }, + { + "name": "parent", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.cloud.aiplatform_v1beta1.services.notebook_service.pagers.ListNotebookExecutionJobsAsyncPager", + "shortName": "list_notebook_execution_jobs" + }, + "description": "Sample for ListNotebookExecutionJobs", + "file": "aiplatform_v1beta1_generated_notebook_service_list_notebook_execution_jobs_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1beta1_generated_NotebookService_ListNotebookExecutionJobs_async", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1beta1_generated_notebook_service_list_notebook_execution_jobs_async.py" }, { "canonical": true, @@ -34925,22 +35577,22 @@ "fullName": "google.cloud.aiplatform_v1beta1.NotebookServiceClient", "shortName": "NotebookServiceClient" }, - "fullName": "google.cloud.aiplatform_v1beta1.NotebookServiceClient.get_notebook_runtime", + "fullName": "google.cloud.aiplatform_v1beta1.NotebookServiceClient.list_notebook_execution_jobs", "method": { - "fullName": "google.cloud.aiplatform.v1beta1.NotebookService.GetNotebookRuntime", + "fullName": 
"google.cloud.aiplatform.v1beta1.NotebookService.ListNotebookExecutionJobs", "service": { "fullName": "google.cloud.aiplatform.v1beta1.NotebookService", "shortName": "NotebookService" }, - "shortName": "GetNotebookRuntime" + "shortName": "ListNotebookExecutionJobs" }, "parameters": [ { "name": "request", - "type": "google.cloud.aiplatform_v1beta1.types.GetNotebookRuntimeRequest" + "type": "google.cloud.aiplatform_v1beta1.types.ListNotebookExecutionJobsRequest" }, { - "name": "name", + "name": "parent", "type": "str" }, { @@ -34956,22 +35608,22 @@ "type": "Sequence[Tuple[str, str]" } ], - "resultType": "google.cloud.aiplatform_v1beta1.types.NotebookRuntime", - "shortName": "get_notebook_runtime" + "resultType": "google.cloud.aiplatform_v1beta1.services.notebook_service.pagers.ListNotebookExecutionJobsPager", + "shortName": "list_notebook_execution_jobs" }, - "description": "Sample for GetNotebookRuntime", - "file": "aiplatform_v1beta1_generated_notebook_service_get_notebook_runtime_sync.py", + "description": "Sample for ListNotebookExecutionJobs", + "file": "aiplatform_v1beta1_generated_notebook_service_list_notebook_execution_jobs_sync.py", "language": "PYTHON", "origin": "API_DEFINITION", - "regionTag": "aiplatform_v1beta1_generated_NotebookService_GetNotebookRuntime_sync", + "regionTag": "aiplatform_v1beta1_generated_NotebookService_ListNotebookExecutionJobs_sync", "segments": [ { - "end": 51, + "end": 52, "start": 27, "type": "FULL" }, { - "end": 51, + "end": 52, "start": 27, "type": "SHORT" }, @@ -34991,12 +35643,12 @@ "type": "REQUEST_EXECUTION" }, { - "end": 52, + "end": 53, "start": 49, "type": "RESPONSE_HANDLING" } ], - "title": "aiplatform_v1beta1_generated_notebook_service_get_notebook_runtime_sync.py" + "title": "aiplatform_v1beta1_generated_notebook_service_list_notebook_execution_jobs_sync.py" }, { "canonical": true, @@ -38592,175 +39244,6 @@ ], "title": "aiplatform_v1beta1_generated_pipeline_service_list_training_pipelines_sync.py" }, - { - 
"canonical": true, - "clientMethod": { - "async": true, - "client": { - "fullName": "google.cloud.aiplatform_v1beta1.PredictionServiceAsyncClient", - "shortName": "PredictionServiceAsyncClient" - }, - "fullName": "google.cloud.aiplatform_v1beta1.PredictionServiceAsyncClient.chat_completions", - "method": { - "fullName": "google.cloud.aiplatform.v1beta1.PredictionService.ChatCompletions", - "service": { - "fullName": "google.cloud.aiplatform.v1beta1.PredictionService", - "shortName": "PredictionService" - }, - "shortName": "ChatCompletions" - }, - "parameters": [ - { - "name": "request", - "type": "google.cloud.aiplatform_v1beta1.types.ChatCompletionsRequest" - }, - { - "name": "endpoint", - "type": "str" - }, - { - "name": "http_body", - "type": "google.api.httpbody_pb2.HttpBody" - }, - { - "name": "retry", - "type": "google.api_core.retry.Retry" - }, - { - "name": "timeout", - "type": "float" - }, - { - "name": "metadata", - "type": "Sequence[Tuple[str, str]" - } - ], - "resultType": "Iterable[google.api.httpbody_pb2.HttpBody]", - "shortName": "chat_completions" - }, - "description": "Sample for ChatCompletions", - "file": "aiplatform_v1beta1_generated_prediction_service_chat_completions_async.py", - "language": "PYTHON", - "origin": "API_DEFINITION", - "regionTag": "aiplatform_v1beta1_generated_PredictionService_ChatCompletions_async", - "segments": [ - { - "end": 52, - "start": 27, - "type": "FULL" - }, - { - "end": 52, - "start": 27, - "type": "SHORT" - }, - { - "end": 40, - "start": 38, - "type": "CLIENT_INITIALIZATION" - }, - { - "end": 45, - "start": 41, - "type": "REQUEST_INITIALIZATION" - }, - { - "end": 48, - "start": 46, - "type": "REQUEST_EXECUTION" - }, - { - "end": 53, - "start": 49, - "type": "RESPONSE_HANDLING" - } - ], - "title": "aiplatform_v1beta1_generated_prediction_service_chat_completions_async.py" - }, - { - "canonical": true, - "clientMethod": { - "client": { - "fullName": "google.cloud.aiplatform_v1beta1.PredictionServiceClient", - 
"shortName": "PredictionServiceClient" - }, - "fullName": "google.cloud.aiplatform_v1beta1.PredictionServiceClient.chat_completions", - "method": { - "fullName": "google.cloud.aiplatform.v1beta1.PredictionService.ChatCompletions", - "service": { - "fullName": "google.cloud.aiplatform.v1beta1.PredictionService", - "shortName": "PredictionService" - }, - "shortName": "ChatCompletions" - }, - "parameters": [ - { - "name": "request", - "type": "google.cloud.aiplatform_v1beta1.types.ChatCompletionsRequest" - }, - { - "name": "endpoint", - "type": "str" - }, - { - "name": "http_body", - "type": "google.api.httpbody_pb2.HttpBody" - }, - { - "name": "retry", - "type": "google.api_core.retry.Retry" - }, - { - "name": "timeout", - "type": "float" - }, - { - "name": "metadata", - "type": "Sequence[Tuple[str, str]" - } - ], - "resultType": "Iterable[google.api.httpbody_pb2.HttpBody]", - "shortName": "chat_completions" - }, - "description": "Sample for ChatCompletions", - "file": "aiplatform_v1beta1_generated_prediction_service_chat_completions_sync.py", - "language": "PYTHON", - "origin": "API_DEFINITION", - "regionTag": "aiplatform_v1beta1_generated_PredictionService_ChatCompletions_sync", - "segments": [ - { - "end": 52, - "start": 27, - "type": "FULL" - }, - { - "end": 52, - "start": 27, - "type": "SHORT" - }, - { - "end": 40, - "start": 38, - "type": "CLIENT_INITIALIZATION" - }, - { - "end": 45, - "start": 41, - "type": "REQUEST_INITIALIZATION" - }, - { - "end": 48, - "start": 46, - "type": "REQUEST_EXECUTION" - }, - { - "end": 53, - "start": 49, - "type": "RESPONSE_HANDLING" - } - ], - "title": "aiplatform_v1beta1_generated_prediction_service_chat_completions_sync.py" - }, { "canonical": true, "clientMethod": { @@ -50111,12 +50594,12 @@ "regionTag": "aiplatform_v1beta1_generated_VertexRagService_RetrieveContexts_async", "segments": [ { - "end": 59, + "end": 55, "start": 27, "type": "FULL" }, { - "end": 59, + "end": 55, "start": 27, "type": "SHORT" }, @@ -50126,18 
+50609,18 @@ "type": "CLIENT_INITIALIZATION" }, { - "end": 53, + "end": 49, "start": 41, "type": "REQUEST_INITIALIZATION" }, { - "end": 56, - "start": 54, + "end": 52, + "start": 50, "type": "REQUEST_EXECUTION" }, { - "end": 60, - "start": 57, + "end": 56, + "start": 53, "type": "RESPONSE_HANDLING" } ], @@ -50195,12 +50678,12 @@ "regionTag": "aiplatform_v1beta1_generated_VertexRagService_RetrieveContexts_sync", "segments": [ { - "end": 59, + "end": 55, "start": 27, "type": "FULL" }, { - "end": 59, + "end": 55, "start": 27, "type": "SHORT" }, @@ -50210,18 +50693,18 @@ "type": "CLIENT_INITIALIZATION" }, { - "end": 53, + "end": 49, "start": 41, "type": "REQUEST_INITIALIZATION" }, { - "end": 56, - "start": 54, + "end": 52, + "start": 50, "type": "REQUEST_EXECUTION" }, { - "end": 60, - "start": 57, + "end": 56, + "start": 53, "type": "RESPONSE_HANDLING" } ], diff --git a/tests/unit/gapic/aiplatform_v1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1/test_dataset_service.py index cb0687adf1..19a9960915 100644 --- a/tests/unit/gapic/aiplatform_v1/test_dataset_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_dataset_service.py @@ -1188,6 +1188,9 @@ def test_create_dataset_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_dataset() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1211,6 +1214,9 @@ def test_create_dataset_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_dataset(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1219,6 +1225,45 @@ def test_create_dataset_non_empty_request_with_auto_populated_field(): ) +def test_create_dataset_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_dataset in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_dataset] = mock_rpc + request = {} + client.create_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_dataset_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1240,6 +1285,56 @@ async def test_create_dataset_empty_call_async(): assert args[0] == dataset_service.CreateDatasetRequest() +@pytest.mark.asyncio +async def test_create_dataset_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_dataset + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_dataset + ] = mock_object + + request = {} + await client.create_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_dataset_async( transport: str = "grpc_asyncio", request_type=dataset_service.CreateDatasetRequest @@ -1487,6 +1582,9 @@ def test_get_dataset_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_dataset() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1510,6 +1608,9 @@ def test_get_dataset_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_dataset(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1518,6 +1619,41 @@ def test_get_dataset_non_empty_request_with_auto_populated_field(): ) +def test_get_dataset_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_dataset in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_dataset] = mock_rpc + request = {} + client.get_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_dataset_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1547,6 +1683,52 @@ async def test_get_dataset_empty_call_async(): assert args[0] == dataset_service.GetDatasetRequest() +@pytest.mark.asyncio +async def test_get_dataset_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_dataset + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_dataset + ] = mock_object + + request = {} + await client.get_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_dataset_async( transport: str = "grpc_asyncio", request_type=dataset_service.GetDatasetRequest @@ -1795,6 +1977,9 @@ def test_update_dataset_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_dataset() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1816,12 +2001,50 @@ def test_update_dataset_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_dataset(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.UpdateDatasetRequest() +def test_update_dataset_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_dataset in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_dataset] = mock_rpc + request = {} + client.update_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_dataset_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1851,6 +2074,52 @@ async def test_update_dataset_empty_call_async(): assert args[0] == dataset_service.UpdateDatasetRequest() +@pytest.mark.asyncio +async def test_update_dataset_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_dataset + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_dataset + ] = mock_object + + request = {} + await client.update_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_dataset_async( transport: str = "grpc_asyncio", request_type=dataset_service.UpdateDatasetRequest @@ -2097,6 +2366,9 @@ def test_list_datasets_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_datasets() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2123,6 +2395,9 @@ def test_list_datasets_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_datasets(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2134,6 +2409,41 @@ def test_list_datasets_non_empty_request_with_auto_populated_field(): ) +def test_list_datasets_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_datasets in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_datasets] = mock_rpc + request = {} + client.list_datasets(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_datasets(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_datasets_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2157,6 +2467,52 @@ async def test_list_datasets_empty_call_async(): assert args[0] == dataset_service.ListDatasetsRequest() +@pytest.mark.asyncio +async def test_list_datasets_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_datasets + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_datasets + ] = mock_object + + request = {} + await client.list_datasets(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_datasets(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_datasets_async( transport: str = "grpc_asyncio", request_type=dataset_service.ListDatasetsRequest @@ -2572,6 +2928,9 @@ def test_delete_dataset_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_dataset() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2595,6 +2954,9 @@ def test_delete_dataset_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_dataset(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2603,6 +2965,45 @@ def test_delete_dataset_non_empty_request_with_auto_populated_field(): ) +def test_delete_dataset_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_dataset in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_dataset] = mock_rpc + request = {} + client.delete_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_dataset_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2624,6 +3025,56 @@ async def test_delete_dataset_empty_call_async(): assert args[0] == dataset_service.DeleteDatasetRequest() +@pytest.mark.asyncio +async def test_delete_dataset_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_dataset + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_dataset + ] = mock_object + + request = {} + await client.delete_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_dataset_async( transport: str = "grpc_asyncio", request_type=dataset_service.DeleteDatasetRequest @@ -2846,6 +3297,9 @@ def test_import_data_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.import_data), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.import_data() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2869,6 +3323,9 @@ def test_import_data_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.import_data), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.import_data(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2877,6 +3334,45 @@ def test_import_data_non_empty_request_with_auto_populated_field(): ) +def test_import_data_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.import_data in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.import_data] = mock_rpc + request = {} + client.import_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.import_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_import_data_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2898,6 +3394,56 @@ async def test_import_data_empty_call_async(): assert args[0] == dataset_service.ImportDataRequest() +@pytest.mark.asyncio +async def test_import_data_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.import_data + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.import_data + ] = mock_object + + request = {} + await client.import_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.import_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_import_data_async( transport: str = "grpc_asyncio", request_type=dataset_service.ImportDataRequest @@ -3142,6 +3688,9 @@ def test_export_data_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.export_data), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.export_data() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3165,6 +3714,9 @@ def test_export_data_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.export_data), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.export_data(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3173,6 +3725,45 @@ def test_export_data_non_empty_request_with_auto_populated_field(): ) +def test_export_data_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.export_data in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.export_data] = mock_rpc + request = {} + client.export_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.export_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_export_data_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3194,6 +3785,56 @@ async def test_export_data_empty_call_async(): assert args[0] == dataset_service.ExportDataRequest() +@pytest.mark.asyncio +async def test_export_data_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.export_data + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.export_data + ] = mock_object + + request = {} + await client.export_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.export_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_export_data_async( transport: str = "grpc_asyncio", request_type=dataset_service.ExportDataRequest @@ -3454,6 +4095,9 @@ def test_create_dataset_version_empty_call(): with mock.patch.object( type(client.transport.create_dataset_version), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_dataset_version() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3479,6 +4123,9 @@ def test_create_dataset_version_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_dataset_version), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_dataset_version(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3487,6 +4134,50 @@ def test_create_dataset_version_non_empty_request_with_auto_populated_field(): ) +def test_create_dataset_version_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_dataset_version + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_dataset_version + ] = mock_rpc + request = {} + client.create_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_dataset_version_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3510,6 +4201,56 @@ async def test_create_dataset_version_empty_call_async(): assert args[0] == dataset_service.CreateDatasetVersionRequest() +@pytest.mark.asyncio +async def test_create_dataset_version_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_dataset_version + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_dataset_version + ] = mock_object + + request = {} + await client.create_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_dataset_version_async( transport: str = "grpc_asyncio", @@ -3757,6 +4498,9 @@ def test_delete_dataset_version_empty_call(): with mock.patch.object( type(client.transport.delete_dataset_version), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_dataset_version() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3782,6 +4526,9 @@ def test_delete_dataset_version_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_dataset_version), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_dataset_version(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3790,6 +4537,50 @@ def test_delete_dataset_version_non_empty_request_with_auto_populated_field(): ) +def test_delete_dataset_version_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_dataset_version + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_dataset_version + ] = mock_rpc + request = {} + client.delete_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_dataset_version_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3813,6 +4604,56 @@ async def test_delete_dataset_version_empty_call_async(): assert args[0] == dataset_service.DeleteDatasetVersionRequest() +@pytest.mark.asyncio +async def test_delete_dataset_version_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_dataset_version + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_dataset_version + ] = mock_object + + request = {} + await client.delete_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_dataset_version_async( transport: str = "grpc_asyncio", @@ -4059,6 +4900,9 @@ def test_get_dataset_version_empty_call(): with mock.patch.object( type(client.transport.get_dataset_version), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_dataset_version() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4084,6 +4928,9 @@ def test_get_dataset_version_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_dataset_version), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_dataset_version(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4092,6 +4939,45 @@ def test_get_dataset_version_non_empty_request_with_auto_populated_field(): ) +def test_get_dataset_version_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_dataset_version in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_dataset_version + ] = mock_rpc + request = {} + client.get_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_dataset_version_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4120,6 +5006,52 @@ async def test_get_dataset_version_empty_call_async(): assert args[0] == dataset_service.GetDatasetVersionRequest() +@pytest.mark.asyncio +async def test_get_dataset_version_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_dataset_version + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_dataset_version + ] = mock_object + + request = {} + await client.get_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_dataset_version_async( transport: str = "grpc_asyncio", @@ -4369,6 +5301,9 @@ def test_list_dataset_versions_empty_call(): with mock.patch.object( type(client.transport.list_dataset_versions), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_dataset_versions() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4397,6 +5332,9 @@ def test_list_dataset_versions_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_dataset_versions), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_dataset_versions(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4408,6 +5346,46 @@ def test_list_dataset_versions_non_empty_request_with_auto_populated_field(): ) +def test_list_dataset_versions_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_dataset_versions + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_dataset_versions + ] = mock_rpc + request = {} + client.list_dataset_versions(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_dataset_versions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_dataset_versions_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4433,6 +5411,52 @@ async def test_list_dataset_versions_empty_call_async(): assert args[0] == dataset_service.ListDatasetVersionsRequest() +@pytest.mark.asyncio +async def test_list_dataset_versions_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_dataset_versions + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_dataset_versions + ] = mock_object + + request = {} + await client.list_dataset_versions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_dataset_versions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_dataset_versions_async( transport: str = "grpc_asyncio", @@ -4871,6 +5895,9 @@ def test_restore_dataset_version_empty_call(): with mock.patch.object( type(client.transport.restore_dataset_version), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.restore_dataset_version() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4896,6 +5923,9 @@ def test_restore_dataset_version_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.restore_dataset_version), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.restore_dataset_version(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4904,6 +5934,50 @@ def test_restore_dataset_version_non_empty_request_with_auto_populated_field(): ) +def test_restore_dataset_version_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.restore_dataset_version + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.restore_dataset_version + ] = mock_rpc + request = {} + client.restore_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.restore_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_restore_dataset_version_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4927,6 +6001,56 @@ async def test_restore_dataset_version_empty_call_async(): assert args[0] == dataset_service.RestoreDatasetVersionRequest() +@pytest.mark.asyncio +async def test_restore_dataset_version_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.restore_dataset_version + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.restore_dataset_version + ] = mock_object + + request = {} + await client.restore_dataset_version(request) + + # Establish that the 
underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.restore_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_restore_dataset_version_async( transport: str = "grpc_asyncio", @@ -5163,6 +6287,9 @@ def test_list_data_items_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_data_items() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5189,6 +6316,9 @@ def test_list_data_items_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_data_items(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5200,6 +6330,41 @@ def test_list_data_items_non_empty_request_with_auto_populated_field(): ) +def test_list_data_items_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_data_items in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_data_items] = mock_rpc + request = {} + client.list_data_items(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_data_items(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_data_items_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5223,6 +6388,52 @@ async def test_list_data_items_empty_call_async(): assert args[0] == dataset_service.ListDataItemsRequest() +@pytest.mark.asyncio +async def test_list_data_items_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_data_items + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_data_items + ] = mock_object + + request = {} + await client.list_data_items(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_data_items(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_data_items_async( transport: str = "grpc_asyncio", request_type=dataset_service.ListDataItemsRequest @@ -5645,6 +6856,9 @@ def test_search_data_items_empty_call(): with mock.patch.object( type(client.transport.search_data_items), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.search_data_items() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5677,6 +6891,9 @@ def test_search_data_items_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.search_data_items), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.search_data_items(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5692,6 +6909,43 @@ def test_search_data_items_non_empty_request_with_auto_populated_field(): ) +def test_search_data_items_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.search_data_items in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.search_data_items + ] = mock_rpc + request = {} + client.search_data_items(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.search_data_items(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_search_data_items_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5717,6 +6971,52 @@ async def test_search_data_items_empty_call_async(): assert args[0] == dataset_service.SearchDataItemsRequest() +@pytest.mark.asyncio +async def test_search_data_items_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.search_data_items + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.search_data_items + ] = mock_object + + request = {} + await client.search_data_items(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.search_data_items(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_search_data_items_async( transport: str = "grpc_asyncio", request_type=dataset_service.SearchDataItemsRequest @@ -6071,6 +7371,9 @@ def test_list_saved_queries_empty_call(): with mock.patch.object( type(client.transport.list_saved_queries), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_saved_queries() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6099,6 +7402,9 @@ def test_list_saved_queries_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_saved_queries), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_saved_queries(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6110,6 +7416,45 @@ def test_list_saved_queries_non_empty_request_with_auto_populated_field(): ) +def test_list_saved_queries_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_saved_queries in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_saved_queries + ] = mock_rpc + request = {} + client.list_saved_queries(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_saved_queries(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_saved_queries_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6135,6 +7480,52 @@ async def test_list_saved_queries_empty_call_async(): assert args[0] == dataset_service.ListSavedQueriesRequest() +@pytest.mark.asyncio +async def test_list_saved_queries_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_saved_queries + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_saved_queries + ] = mock_object + + request = {} + await client.list_saved_queries(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_saved_queries(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_saved_queries_async( transport: str = "grpc_asyncio", @@ -6573,6 +7964,9 @@ def test_delete_saved_query_empty_call(): with mock.patch.object( type(client.transport.delete_saved_query), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_saved_query() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6598,6 +7992,9 @@ def test_delete_saved_query_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_saved_query), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_saved_query(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6606,6 +8003,49 @@ def test_delete_saved_query_non_empty_request_with_auto_populated_field(): ) +def test_delete_saved_query_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_saved_query in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_saved_query + ] = mock_rpc + request = {} + client.delete_saved_query(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_saved_query(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_saved_query_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6629,6 +8069,56 @@ async def test_delete_saved_query_empty_call_async(): assert args[0] == dataset_service.DeleteSavedQueryRequest() +@pytest.mark.asyncio +async def test_delete_saved_query_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_saved_query + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_saved_query + ] = mock_object + + request = {} + await client.delete_saved_query(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_saved_query(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_saved_query_async( transport: str = "grpc_asyncio", @@ -6873,6 +8363,9 @@ def test_get_annotation_spec_empty_call(): with mock.patch.object( type(client.transport.get_annotation_spec), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_annotation_spec() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6898,6 +8391,9 @@ def test_get_annotation_spec_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_annotation_spec), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_annotation_spec(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6906,6 +8402,45 @@ def test_get_annotation_spec_non_empty_request_with_auto_populated_field(): ) +def test_get_annotation_spec_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_annotation_spec in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_annotation_spec + ] = mock_rpc + request = {} + client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_annotation_spec(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_annotation_spec_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6933,6 +8468,52 @@ async def test_get_annotation_spec_empty_call_async(): assert args[0] == dataset_service.GetAnnotationSpecRequest() +@pytest.mark.asyncio +async def test_get_annotation_spec_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_annotation_spec + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_annotation_spec + ] = mock_object + + request = {} + await client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_annotation_spec(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_annotation_spec_async( transport: str = "grpc_asyncio", @@ -7176,6 +8757,9 @@ def test_list_annotations_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_annotations() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7202,6 +8786,9 @@ def test_list_annotations_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_annotations(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7213,6 +8800,43 @@ def test_list_annotations_non_empty_request_with_auto_populated_field(): ) +def test_list_annotations_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_annotations in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_annotations + ] = mock_rpc + request = {} + client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_annotations(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_annotations_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7236,6 +8860,52 @@ async def test_list_annotations_empty_call_async(): assert args[0] == dataset_service.ListAnnotationsRequest() +@pytest.mark.asyncio +async def test_list_annotations_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_annotations + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_annotations + ] = mock_object + + request = {} + await client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_annotations(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_annotations_async( transport: str = "grpc_asyncio", request_type=dataset_service.ListAnnotationsRequest @@ -7745,6 +9415,46 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_dataset_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_dataset in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_dataset] = mock_rpc + + request = {} + client.create_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_dataset_rest_required_fields( request_type=dataset_service.CreateDatasetRequest, ): @@ -8031,6 +9741,42 @@ def test_get_dataset_rest(request_type): assert response.metadata_artifact == "metadata_artifact_value" +def test_get_dataset_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_dataset in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_dataset] = mock_rpc + + request = {} + client.get_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_dataset_rest_required_fields( request_type=dataset_service.GetDatasetRequest, ): @@ -8413,6 +10159,42 @@ def get_message_fields(field): assert response.metadata_artifact == "metadata_artifact_value" +def test_update_dataset_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_dataset in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_dataset] = mock_rpc + + request = {} + client.update_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_dataset_rest_required_fields( request_type=dataset_service.UpdateDatasetRequest, ): @@ -8681,13 +10463,49 @@ def test_list_datasets_rest(request_type): return_value = dataset_service.ListDatasetsResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) - response_value._content = json_return_value.encode("UTF-8") - req.return_value = response_value - response = client.list_datasets(request) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.list_datasets(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDatasetsPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_datasets_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_datasets in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_datasets] = mock_rpc + + request = {} + client.list_datasets(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListDatasetsPager) - assert response.next_page_token == "next_page_token_value" + client.list_datasets(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 def test_list_datasets_rest_required_fields( @@ -9029,6 +10847,46 @@ def test_delete_dataset_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_dataset_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_dataset in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_dataset] = mock_rpc + + request = {} + client.delete_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_dataset_rest_required_fields( request_type=dataset_service.DeleteDatasetRequest, ): @@ -9287,6 +11145,46 @@ def test_import_data_rest(request_type): assert response.operation.name == "operations/spam" +def test_import_data_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.import_data in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.import_data] = mock_rpc + + request = {} + client.import_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.import_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_import_data_rest_required_fields( request_type=dataset_service.ImportDataRequest, ): @@ -9561,6 +11459,46 @@ def test_export_data_rest(request_type): assert response.operation.name == "operations/spam" +def test_export_data_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.export_data in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.export_data] = mock_rpc + + request = {} + client.export_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.export_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_export_data_rest_required_fields( request_type=dataset_service.ExportDataRequest, ): @@ -9924,6 +11862,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_dataset_version_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_dataset_version + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_dataset_version + ] = mock_rpc + + request = {} + client.create_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_dataset_version_rest_required_fields( request_type=dataset_service.CreateDatasetVersionRequest, ): @@ -10198,6 +12181,51 @@ def test_delete_dataset_version_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_dataset_version_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_dataset_version + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_dataset_version + ] = mock_rpc + + request = {} + client.delete_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_dataset_version_rest_required_fields( request_type=dataset_service.DeleteDatasetVersionRequest, ): @@ -10474,6 +12502,46 @@ def test_get_dataset_version_rest(request_type): assert response.display_name == "display_name_value" +def test_get_dataset_version_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_dataset_version in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_dataset_version + ] = mock_rpc + + request = {} + client.get_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_dataset_version_rest_required_fields( request_type=dataset_service.GetDatasetVersionRequest, ): @@ -10747,6 +12815,47 @@ def test_list_dataset_versions_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_dataset_versions_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_dataset_versions + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_dataset_versions + ] = mock_rpc + + request = {} + client.list_dataset_versions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_dataset_versions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_dataset_versions_rest_required_fields( request_type=dataset_service.ListDatasetVersionsRequest, ): @@ -11093,6 +13202,51 @@ def test_restore_dataset_version_rest(request_type): assert response.operation.name == "operations/spam" +def test_restore_dataset_version_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.restore_dataset_version + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.restore_dataset_version + ] = mock_rpc + + request = {} + client.restore_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.restore_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_restore_dataset_version_rest_required_fields( request_type=dataset_service.RestoreDatasetVersionRequest, ): @@ -11361,6 +13515,42 @@ def test_list_data_items_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_data_items_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_data_items in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_data_items] = mock_rpc + + request = {} + client.list_data_items(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_data_items(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_data_items_rest_required_fields( request_type=dataset_service.ListDataItemsRequest, ): @@ -11710,6 +13900,44 @@ def test_search_data_items_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_search_data_items_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.search_data_items in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.search_data_items + ] = mock_rpc + + request = {} + client.search_data_items(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_data_items(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_search_data_items_rest_required_fields( request_type=dataset_service.SearchDataItemsRequest, ): @@ -12014,6 +14242,46 @@ def test_list_saved_queries_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_saved_queries_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_saved_queries in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_saved_queries + ] = mock_rpc + + request = {} + client.list_saved_queries(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_saved_queries(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_saved_queries_rest_required_fields( request_type=dataset_service.ListSavedQueriesRequest, ): @@ -12360,6 +14628,50 @@ def test_delete_saved_query_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_saved_query_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_saved_query in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_saved_query + ] = mock_rpc + + request = {} + client.delete_saved_query(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_saved_query(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_saved_query_rest_required_fields( request_type=dataset_service.DeleteSavedQueryRequest, ): @@ -12634,6 +14946,46 @@ def test_get_annotation_spec_rest(request_type): assert response.etag == "etag_value" +def test_get_annotation_spec_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_annotation_spec in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_annotation_spec + ] = mock_rpc + + request = {} + client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_annotation_spec(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_annotation_spec_rest_required_fields( request_type=dataset_service.GetAnnotationSpecRequest, ): @@ -12909,6 +15261,44 @@ def test_list_annotations_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_annotations_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_annotations in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_annotations + ] = mock_rpc + + request = {} + client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_annotations(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_annotations_rest_required_fields( request_type=dataset_service.ListAnnotationsRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_deployment_resource_pool_service.py b/tests/unit/gapic/aiplatform_v1/test_deployment_resource_pool_service.py index 0e352734dc..3788dc0ba6 100644 --- a/tests/unit/gapic/aiplatform_v1/test_deployment_resource_pool_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_deployment_resource_pool_service.py @@ -1285,6 +1285,9 @@ def test_create_deployment_resource_pool_empty_call(): with mock.patch.object( type(client.transport.create_deployment_resource_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_deployment_resource_pool() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1314,6 +1317,9 @@ def test_create_deployment_resource_pool_non_empty_request_with_auto_populated_f with mock.patch.object( type(client.transport.create_deployment_resource_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_deployment_resource_pool(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1325,6 +1331,50 @@ def test_create_deployment_resource_pool_non_empty_request_with_auto_populated_f ) +def test_create_deployment_resource_pool_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_deployment_resource_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_deployment_resource_pool + ] = mock_rpc + request = {} + client.create_deployment_resource_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_deployment_resource_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_deployment_resource_pool_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1351,6 +1401,56 @@ async def test_create_deployment_resource_pool_empty_call_async(): ) +@pytest.mark.asyncio +async def test_create_deployment_resource_pool_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_deployment_resource_pool + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_deployment_resource_pool + ] = mock_object + + request = {} + await client.create_deployment_resource_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_deployment_resource_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_deployment_resource_pool_async( transport: str = "grpc_asyncio", @@ -1623,6 +1723,9 @@ def test_get_deployment_resource_pool_empty_call(): with mock.patch.object( type(client.transport.get_deployment_resource_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_deployment_resource_pool() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1651,6 +1754,9 @@ def test_get_deployment_resource_pool_non_empty_request_with_auto_populated_fiel with mock.patch.object( type(client.transport.get_deployment_resource_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_deployment_resource_pool(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1661,6 +1767,46 @@ def test_get_deployment_resource_pool_non_empty_request_with_auto_populated_fiel ) +def test_get_deployment_resource_pool_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_deployment_resource_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_deployment_resource_pool + ] = mock_rpc + request = {} + client.get_deployment_resource_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_deployment_resource_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_deployment_resource_pool_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1689,6 +1835,52 @@ async def test_get_deployment_resource_pool_empty_call_async(): ) +@pytest.mark.asyncio +async def test_get_deployment_resource_pool_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_deployment_resource_pool + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_deployment_resource_pool + ] = mock_object + + request = {} + await client.get_deployment_resource_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_deployment_resource_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_deployment_resource_pool_async( transport: str = "grpc_asyncio", @@ -1934,6 +2126,9 @@ def test_list_deployment_resource_pools_empty_call(): with mock.patch.object( type(client.transport.list_deployment_resource_pools), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_deployment_resource_pools() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1963,6 +2158,9 @@ def test_list_deployment_resource_pools_non_empty_request_with_auto_populated_fi with mock.patch.object( type(client.transport.list_deployment_resource_pools), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_deployment_resource_pools(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1974,6 +2172,46 @@ def test_list_deployment_resource_pools_non_empty_request_with_auto_populated_fi ) +def test_list_deployment_resource_pools_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_deployment_resource_pools + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_deployment_resource_pools + ] = mock_rpc + request = {} + client.list_deployment_resource_pools(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_deployment_resource_pools(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_deployment_resource_pools_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2002,6 +2240,52 @@ async def test_list_deployment_resource_pools_empty_call_async(): ) +@pytest.mark.asyncio +async def test_list_deployment_resource_pools_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_deployment_resource_pools + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_deployment_resource_pools + ] = mock_object + + request = {} + await client.list_deployment_resource_pools(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_deployment_resource_pools(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_deployment_resource_pools_async( transport: str = "grpc_asyncio", @@ -2452,6 +2736,9 @@ def test_delete_deployment_resource_pool_empty_call(): with mock.patch.object( type(client.transport.delete_deployment_resource_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_deployment_resource_pool() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2480,6 +2767,9 @@ def test_delete_deployment_resource_pool_non_empty_request_with_auto_populated_f with mock.patch.object( type(client.transport.delete_deployment_resource_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_deployment_resource_pool(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2490,6 +2780,50 @@ def test_delete_deployment_resource_pool_non_empty_request_with_auto_populated_f ) +def test_delete_deployment_resource_pool_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_deployment_resource_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_deployment_resource_pool + ] = mock_rpc + request = {} + client.delete_deployment_resource_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_deployment_resource_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_deployment_resource_pool_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2516,6 +2850,56 @@ async def test_delete_deployment_resource_pool_empty_call_async(): ) +@pytest.mark.asyncio +async def test_delete_deployment_resource_pool_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_deployment_resource_pool + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_deployment_resource_pool + ] = mock_object + + request = {} + await client.delete_deployment_resource_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_deployment_resource_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_deployment_resource_pool_async( transport: str = "grpc_asyncio", @@ -2762,6 +3146,9 @@ def test_query_deployed_models_empty_call(): with mock.patch.object( type(client.transport.query_deployed_models), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.query_deployed_models() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2788,6 +3175,9 @@ def test_query_deployed_models_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.query_deployed_models), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.query_deployed_models(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2797,6 +3187,46 @@ def test_query_deployed_models_non_empty_request_with_auto_populated_field(): ) +def test_query_deployed_models_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.query_deployed_models + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.query_deployed_models + ] = mock_rpc + request = {} + client.query_deployed_models(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.query_deployed_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_query_deployed_models_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2824,6 +3254,52 @@ async def test_query_deployed_models_empty_call_async(): assert args[0] == deployment_resource_pool_service.QueryDeployedModelsRequest() +@pytest.mark.asyncio +async def test_query_deployed_models_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.query_deployed_models + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.query_deployed_models + ] = mock_object + + request = {} + await client.query_deployed_models(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.query_deployed_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_query_deployed_models_async( transport: str = "grpc_asyncio", @@ -3262,6 +3738,51 @@ def test_create_deployment_resource_pool_rest(request_type): assert response.operation.name == "operations/spam" +def test_create_deployment_resource_pool_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_deployment_resource_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_deployment_resource_pool + ] = mock_rpc + + request = {} + client.create_deployment_resource_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_deployment_resource_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_deployment_resource_pool_rest_required_fields( request_type=deployment_resource_pool_service.CreateDeploymentResourcePoolRequest, ): @@ -3560,6 +4081,47 @@ def test_get_deployment_resource_pool_rest(request_type): assert response.name == "name_value" +def test_get_deployment_resource_pool_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_deployment_resource_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_deployment_resource_pool + ] = mock_rpc + + request = {} + client.get_deployment_resource_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_deployment_resource_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_deployment_resource_pool_rest_required_fields( request_type=deployment_resource_pool_service.GetDeploymentResourcePoolRequest, ): @@ -3846,6 +4408,47 @@ def test_list_deployment_resource_pools_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_deployment_resource_pools_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_deployment_resource_pools + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_deployment_resource_pools + ] = mock_rpc + + request = {} + client.list_deployment_resource_pools(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_deployment_resource_pools(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_deployment_resource_pools_rest_required_fields( request_type=deployment_resource_pool_service.ListDeploymentResourcePoolsRequest, ): @@ -4211,6 +4814,51 @@ def test_delete_deployment_resource_pool_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_deployment_resource_pool_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_deployment_resource_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_deployment_resource_pool + ] = mock_rpc + + request = {} + client.delete_deployment_resource_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_deployment_resource_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_deployment_resource_pool_rest_required_fields( request_type=deployment_resource_pool_service.DeleteDeploymentResourcePoolRequest, ): @@ -4494,6 +5142,47 @@ def test_query_deployed_models_rest(request_type): assert response.total_endpoint_count == 2156 +def test_query_deployed_models_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.query_deployed_models + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.query_deployed_models + ] = mock_rpc + + request = {} + client.query_deployed_models(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.query_deployed_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_query_deployed_models_rest_required_fields( request_type=deployment_resource_pool_service.QueryDeployedModelsRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py index 54f1f3f70b..d91d623d52 100644 --- a/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py @@ -65,6 +65,7 @@ from google.cloud.aiplatform_v1.types import io from google.cloud.aiplatform_v1.types import machine_resources from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.cloud.aiplatform_v1.types import service_networking from google.cloud.location import locations_pb2 from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import options_pb2 # type: ignore @@ -1210,6 +1211,9 @@ def test_create_endpoint_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1234,6 +1238,9 @@ def test_create_endpoint_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_endpoint(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1243,6 +1250,45 @@ def test_create_endpoint_non_empty_request_with_auto_populated_field(): ) +def test_create_endpoint_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_endpoint in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_endpoint] = mock_rpc + request = {} + client.create_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_endpoint_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1264,6 +1310,56 @@ async def test_create_endpoint_empty_call_async(): assert args[0] == endpoint_service.CreateEndpointRequest() +@pytest.mark.asyncio +async def test_create_endpoint_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_endpoint + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_endpoint + ] = mock_object + + request = {} + await client.create_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_endpoint_async( transport: str = "grpc_asyncio", request_type=endpoint_service.CreateEndpointRequest @@ -1524,6 +1620,9 @@ def test_get_endpoint_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1547,6 +1646,9 @@ def test_get_endpoint_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_endpoint(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1555,6 +1657,41 @@ def test_get_endpoint_non_empty_request_with_auto_populated_field(): ) +def test_get_endpoint_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_endpoint in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_endpoint] = mock_rpc + request = {} + client.get_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_endpoint_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1584,6 +1721,52 @@ async def test_get_endpoint_empty_call_async(): assert args[0] == endpoint_service.GetEndpointRequest() +@pytest.mark.asyncio +async def test_get_endpoint_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_endpoint + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_endpoint + ] = mock_object + + request = {} + await client.get_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_endpoint_async( transport: str = "grpc_asyncio", request_type=endpoint_service.GetEndpointRequest @@ -1823,6 +2006,9 @@ def test_list_endpoints_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_endpoints() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1849,6 +2035,9 @@ def test_list_endpoints_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_endpoints(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1860,6 +2049,41 @@ def test_list_endpoints_non_empty_request_with_auto_populated_field(): ) +def test_list_endpoints_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_endpoints in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_endpoints] = mock_rpc + request = {} + client.list_endpoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_endpoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_endpoints_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1883,6 +2107,52 @@ async def test_list_endpoints_empty_call_async(): assert args[0] == endpoint_service.ListEndpointsRequest() +@pytest.mark.asyncio +async def test_list_endpoints_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_endpoints + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_endpoints + ] = mock_object + + request = {} + await client.list_endpoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_endpoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_endpoints_async( transport: str = "grpc_asyncio", request_type=endpoint_service.ListEndpointsRequest @@ -2316,6 +2586,9 @@ def test_update_endpoint_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2337,12 +2610,50 @@ def test_update_endpoint_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_endpoint(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.UpdateEndpointRequest() +def test_update_endpoint_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_endpoint in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_endpoint] = mock_rpc + request = {} + client.update_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_endpoint_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2372,6 +2683,52 @@ async def test_update_endpoint_empty_call_async(): assert args[0] == endpoint_service.UpdateEndpointRequest() +@pytest.mark.asyncio +async def test_update_endpoint_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_endpoint + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_endpoint + ] = mock_object + + request = {} + await client.update_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_endpoint_async( transport: str = "grpc_asyncio", request_type=endpoint_service.UpdateEndpointRequest @@ -2622,6 +2979,9 @@ def test_delete_endpoint_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2645,6 +3005,9 @@ def test_delete_endpoint_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_endpoint(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2653,6 +3016,45 @@ def test_delete_endpoint_non_empty_request_with_auto_populated_field(): ) +def test_delete_endpoint_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_endpoint in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_endpoint] = mock_rpc + request = {} + client.delete_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_endpoint_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2674,6 +3076,56 @@ async def test_delete_endpoint_empty_call_async(): assert args[0] == endpoint_service.DeleteEndpointRequest() +@pytest.mark.asyncio +async def test_delete_endpoint_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_endpoint + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_endpoint + ] = mock_object + + request = {} + await client.delete_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_endpoint_async( transport: str = "grpc_asyncio", request_type=endpoint_service.DeleteEndpointRequest @@ -2896,6 +3348,9 @@ def test_deploy_model_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.deploy_model() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2919,6 +3374,9 @@ def test_deploy_model_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.deploy_model(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2927,6 +3385,45 @@ def test_deploy_model_non_empty_request_with_auto_populated_field(): ) +def test_deploy_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.deploy_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.deploy_model] = mock_rpc + request = {} + client.deploy_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.deploy_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_deploy_model_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2948,6 +3445,56 @@ async def test_deploy_model_empty_call_async(): assert args[0] == endpoint_service.DeployModelRequest() +@pytest.mark.asyncio +async def test_deploy_model_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.deploy_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.deploy_model + ] = mock_object + + request = {} + await client.deploy_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.deploy_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_deploy_model_async( transport: str = "grpc_asyncio", request_type=endpoint_service.DeployModelRequest @@ -3226,6 +3773,9 @@ def test_undeploy_model_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.undeploy_model() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3250,6 +3800,9 @@ def test_undeploy_model_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.undeploy_model(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3259,6 +3812,45 @@ def test_undeploy_model_non_empty_request_with_auto_populated_field(): ) +def test_undeploy_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.undeploy_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.undeploy_model] = mock_rpc + request = {} + client.undeploy_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.undeploy_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_undeploy_model_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3280,6 +3872,56 @@ async def test_undeploy_model_empty_call_async(): assert args[0] == endpoint_service.UndeployModelRequest() +@pytest.mark.asyncio +async def test_undeploy_model_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.undeploy_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.undeploy_model + ] = mock_object + + request = {} + await client.undeploy_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.undeploy_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_undeploy_model_async( transport: str = "grpc_asyncio", request_type=endpoint_service.UndeployModelRequest @@ -3526,6 +4168,9 @@ def test_mutate_deployed_model_empty_call(): with mock.patch.object( type(client.transport.mutate_deployed_model), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.mutate_deployed_model() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3551,6 +4196,9 @@ def test_mutate_deployed_model_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.mutate_deployed_model), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.mutate_deployed_model(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3559,6 +4207,50 @@ def test_mutate_deployed_model_non_empty_request_with_auto_populated_field(): ) +def test_mutate_deployed_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.mutate_deployed_model + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.mutate_deployed_model + ] = mock_rpc + request = {} + client.mutate_deployed_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.mutate_deployed_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_mutate_deployed_model_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3582,6 +4274,56 @@ async def test_mutate_deployed_model_empty_call_async(): assert args[0] == endpoint_service.MutateDeployedModelRequest() +@pytest.mark.asyncio +async def test_mutate_deployed_model_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.mutate_deployed_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.mutate_deployed_model + ] = mock_object + + request = {} + await client.mutate_deployed_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.mutate_deployed_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_mutate_deployed_model_async( transport: str = "grpc_asyncio", @@ -3939,6 +4681,13 @@ def test_create_endpoint_rest(request_type): "encryption_spec": {"kms_key_name": "kms_key_name_value"}, "network": "network_value", "enable_private_service_connect": True, + "private_service_connect_config": { + "enable_private_service_connect": True, + "project_allowlist": [ + "project_allowlist_value1", + "project_allowlist_value2", + ], + }, "model_deployment_monitoring_job": "model_deployment_monitoring_job_value", "predict_request_response_logging_config": { "enabled": True, @@ -4033,6 +4782,46 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_endpoint_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_endpoint in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[client._transport.create_endpoint] = mock_rpc + + request = {} + client.create_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_endpoint_rest_required_fields( request_type=endpoint_service.CreateEndpointRequest, ): @@ -4326,6 +5115,42 @@ def test_get_endpoint_rest(request_type): ) +def test_get_endpoint_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_endpoint in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_endpoint] = mock_rpc + + request = {} + client.get_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_endpoint_rest_required_fields( request_type=endpoint_service.GetEndpointRequest, ): @@ -4592,6 +5417,42 @@ def test_list_endpoints_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_endpoints_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_endpoints in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_endpoints] = mock_rpc + + request = {} + client.list_endpoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_endpoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_endpoints_rest_required_fields( request_type=endpoint_service.ListEndpointsRequest, ): @@ -5009,6 +5870,13 @@ def test_update_endpoint_rest(request_type): "encryption_spec": {"kms_key_name": "kms_key_name_value"}, "network": "network_value", "enable_private_service_connect": True, + "private_service_connect_config": { + "enable_private_service_connect": True, + "project_allowlist": [ + "project_allowlist_value1", + "project_allowlist_value2", + ], + }, "model_deployment_monitoring_job": "model_deployment_monitoring_job_value", "predict_request_response_logging_config": { "enabled": True, @@ -5123,6 +5991,42 @@ def get_message_fields(field): ) +def test_update_endpoint_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_endpoint in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_endpoint] = mock_rpc + + request = {} + client.update_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_endpoint_rest_required_fields( request_type=endpoint_service.UpdateEndpointRequest, ): @@ -5397,6 +6301,46 @@ def test_delete_endpoint_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_endpoint_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_endpoint in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_endpoint] = mock_rpc + + request = {} + client.delete_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_endpoint_rest_required_fields( request_type=endpoint_service.DeleteEndpointRequest, ): @@ -5657,6 +6601,46 @@ def test_deploy_model_rest(request_type): assert response.operation.name == "operations/spam" +def test_deploy_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.deploy_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.deploy_model] = mock_rpc + + request = {} + client.deploy_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.deploy_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_deploy_model_rest_required_fields( request_type=endpoint_service.DeployModelRequest, ): @@ -5943,6 +6927,46 @@ def test_undeploy_model_rest(request_type): assert response.operation.name == "operations/spam" +def test_undeploy_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.undeploy_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.undeploy_model] = mock_rpc + + request = {} + client.undeploy_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.undeploy_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_undeploy_model_rest_required_fields( request_type=endpoint_service.UndeployModelRequest, ): @@ -6221,6 +7245,51 @@ def test_mutate_deployed_model_rest(request_type): assert response.operation.name == "operations/spam" +def test_mutate_deployed_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.mutate_deployed_model + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.mutate_deployed_model + ] = mock_rpc + + request = {} + client.mutate_deployed_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.mutate_deployed_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_mutate_deployed_model_rest_required_fields( request_type=endpoint_service.MutateDeployedModelRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_feature_online_store_admin_service.py b/tests/unit/gapic/aiplatform_v1/test_feature_online_store_admin_service.py index 2ed51d72f1..30f1d4932a 100644 --- a/tests/unit/gapic/aiplatform_v1/test_feature_online_store_admin_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_feature_online_store_admin_service.py @@ -1293,6 +1293,9 @@ def test_create_feature_online_store_empty_call(): with mock.patch.object( type(client.transport.create_feature_online_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_feature_online_store() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1322,6 +1325,9 @@ def test_create_feature_online_store_non_empty_request_with_auto_populated_field with mock.patch.object( type(client.transport.create_feature_online_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_feature_online_store(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1333,6 +1339,50 @@ def test_create_feature_online_store_non_empty_request_with_auto_populated_field ) +def test_create_feature_online_store_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_feature_online_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_feature_online_store + ] = mock_rpc + request = {} + client.create_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_feature_online_store_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1359,6 +1409,56 @@ async def test_create_feature_online_store_empty_call_async(): ) +@pytest.mark.asyncio +async def test_create_feature_online_store_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_feature_online_store + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_feature_online_store + ] = mock_object + + request = {} + await client.create_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_feature_online_store_async( transport: str = "grpc_asyncio", @@ -1659,6 +1759,9 @@ def test_get_feature_online_store_empty_call(): with mock.patch.object( type(client.transport.get_feature_online_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_feature_online_store() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1686,6 +1789,9 @@ def test_get_feature_online_store_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_feature_online_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_feature_online_store(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1696,6 +1802,46 @@ def test_get_feature_online_store_non_empty_request_with_auto_populated_field(): ) +def test_get_feature_online_store_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_feature_online_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_feature_online_store + ] = mock_rpc + request = {} + client.get_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_online_store_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1725,6 +1871,52 @@ async def test_get_feature_online_store_empty_call_async(): ) +@pytest.mark.asyncio +async def test_get_feature_online_store_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_feature_online_store + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_feature_online_store + ] = mock_object + + request = {} + await client.get_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_online_store_async( transport: str = "grpc_asyncio", @@ -1974,6 +2166,9 @@ def test_list_feature_online_stores_empty_call(): with mock.patch.object( type(client.transport.list_feature_online_stores), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_feature_online_stores() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2005,6 +2200,9 @@ def test_list_feature_online_stores_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.list_feature_online_stores), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_feature_online_stores(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2018,6 +2216,46 @@ def test_list_feature_online_stores_non_empty_request_with_auto_populated_field( ) +def test_list_feature_online_stores_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_feature_online_stores + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_feature_online_stores + ] = mock_rpc + request = {} + client.list_feature_online_stores(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_feature_online_stores(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_feature_online_stores_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2046,6 +2284,52 @@ async def test_list_feature_online_stores_empty_call_async(): ) +@pytest.mark.asyncio +async def test_list_feature_online_stores_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_feature_online_stores + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_feature_online_stores + ] = mock_object + + request = {} + await client.list_feature_online_stores(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_feature_online_stores(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_feature_online_stores_async( transport: str = "grpc_asyncio", @@ -2494,6 +2778,9 @@ def test_update_feature_online_store_empty_call(): with mock.patch.object( type(client.transport.update_feature_online_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_feature_online_store() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2520,6 +2807,9 @@ def test_update_feature_online_store_non_empty_request_with_auto_populated_field with mock.patch.object( type(client.transport.update_feature_online_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_feature_online_store(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2529,6 +2819,50 @@ def test_update_feature_online_store_non_empty_request_with_auto_populated_field ) +def test_update_feature_online_store_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_feature_online_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_feature_online_store + ] = mock_rpc + request = {} + client.update_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_feature_online_store_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2555,6 +2889,56 @@ async def test_update_feature_online_store_empty_call_async(): ) +@pytest.mark.asyncio +async def test_update_feature_online_store_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_feature_online_store + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_feature_online_store + ] = mock_object + + request = {} + await client.update_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_feature_online_store_async( transport: str = "grpc_asyncio", @@ -2838,6 +3222,9 @@ def test_delete_feature_online_store_empty_call(): with mock.patch.object( type(client.transport.delete_feature_online_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_feature_online_store() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2866,6 +3253,9 @@ def test_delete_feature_online_store_non_empty_request_with_auto_populated_field with mock.patch.object( type(client.transport.delete_feature_online_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_feature_online_store(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2876,6 +3266,50 @@ def test_delete_feature_online_store_non_empty_request_with_auto_populated_field ) +def test_delete_feature_online_store_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_feature_online_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_feature_online_store + ] = mock_rpc + request = {} + client.delete_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_online_store_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2902,6 +3336,56 @@ async def test_delete_feature_online_store_empty_call_async(): ) +@pytest.mark.asyncio +async def test_delete_feature_online_store_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_feature_online_store + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_feature_online_store + ] = mock_object + + request = {} + await client.delete_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_online_store_async( transport: str = "grpc_asyncio", @@ -3149,6 +3633,9 @@ def test_create_feature_view_empty_call(): with mock.patch.object( type(client.transport.create_feature_view), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_feature_view() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3175,6 +3662,9 @@ def test_create_feature_view_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_feature_view), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_feature_view(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3184,6 +3674,49 @@ def test_create_feature_view_non_empty_request_with_auto_populated_field(): ) +def test_create_feature_view_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_feature_view in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_feature_view + ] = mock_rpc + request = {} + client.create_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_feature_view_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3207,6 +3740,56 @@ async def test_create_feature_view_empty_call_async(): assert args[0] == feature_online_store_admin_service.CreateFeatureViewRequest() +@pytest.mark.asyncio +async def test_create_feature_view_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_feature_view + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_feature_view + ] = mock_object + + request = {} + await client.create_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_feature_view_async( transport: str = "grpc_asyncio", @@ -3489,6 +4072,9 @@ def test_get_feature_view_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_feature_view), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_feature_view() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3512,6 +4098,9 @@ def test_get_feature_view_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_feature_view), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_feature_view(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3520,12 +4109,49 @@ def test_get_feature_view_non_empty_request_with_auto_populated_field(): ) -@pytest.mark.asyncio -async def test_get_feature_view_empty_call_async(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. 
- client = FeatureOnlineStoreAdminServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), +def test_get_feature_view_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_feature_view in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_feature_view + ] = mock_rpc + request = {} + client.get_feature_view(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_feature_view_empty_call_async(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. 
+ client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", ) @@ -3544,6 +4170,52 @@ async def test_get_feature_view_empty_call_async(): assert args[0] == feature_online_store_admin_service.GetFeatureViewRequest() +@pytest.mark.asyncio +async def test_get_feature_view_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_feature_view + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_feature_view + ] = mock_object + + request = {} + await client.get_feature_view(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.get_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_view_async( transport: str = "grpc_asyncio", @@ -3779,6 +4451,9 @@ def test_list_feature_views_empty_call(): with mock.patch.object( type(client.transport.list_feature_views), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_feature_views() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3807,6 +4482,9 @@ def test_list_feature_views_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_feature_views), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_feature_views(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3818,6 +4496,45 @@ def test_list_feature_views_non_empty_request_with_auto_populated_field(): ) +def test_list_feature_views_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_feature_views in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_feature_views + ] = mock_rpc + request = {} + client.list_feature_views(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_feature_views(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_feature_views_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3843,6 +4560,52 @@ async def test_list_feature_views_empty_call_async(): assert args[0] == feature_online_store_admin_service.ListFeatureViewsRequest() +@pytest.mark.asyncio +async def test_list_feature_views_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_feature_views + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_feature_views + ] = mock_object + + request = {} + await client.list_feature_views(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_feature_views(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_feature_views_async( transport: str = "grpc_asyncio", @@ -4287,6 +5050,9 @@ def test_update_feature_view_empty_call(): with mock.patch.object( type(client.transport.update_feature_view), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_feature_view() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4310,12 +5076,58 @@ def test_update_feature_view_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.update_feature_view), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_feature_view(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == feature_online_store_admin_service.UpdateFeatureViewRequest() +def test_update_feature_view_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_feature_view in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.update_feature_view + ] = mock_rpc + request = {} + client.update_feature_view(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_feature_view_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4339,6 +5151,56 @@ async def test_update_feature_view_empty_call_async(): assert args[0] == feature_online_store_admin_service.UpdateFeatureViewRequest() +@pytest.mark.asyncio +async def test_update_feature_view_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_feature_view + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_feature_view + ] = mock_object + + request = {} + await client.update_feature_view(request) + + # Establish that the underlying 
gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_feature_view_async( transport: str = "grpc_asyncio", @@ -4610,6 +5472,9 @@ def test_delete_feature_view_empty_call(): with mock.patch.object( type(client.transport.delete_feature_view), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_feature_view() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4635,6 +5500,9 @@ def test_delete_feature_view_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_feature_view), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_feature_view(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4643,6 +5511,49 @@ def test_delete_feature_view_non_empty_request_with_auto_populated_field(): ) +def test_delete_feature_view_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_feature_view in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_feature_view + ] = mock_rpc + request = {} + client.delete_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_view_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4666,6 +5577,56 @@ async def test_delete_feature_view_empty_call_async(): assert args[0] == feature_online_store_admin_service.DeleteFeatureViewRequest() +@pytest.mark.asyncio +async def test_delete_feature_view_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_feature_view + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_feature_view + ] = mock_object + + request = {} + await client.delete_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_view_async( transport: str = "grpc_asyncio", @@ -4908,6 +5869,9 @@ def test_sync_feature_view_empty_call(): with mock.patch.object( type(client.transport.sync_feature_view), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.sync_feature_view() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4933,6 +5897,9 @@ def test_sync_feature_view_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.sync_feature_view), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.sync_feature_view(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4941,6 +5908,43 @@ def test_sync_feature_view_non_empty_request_with_auto_populated_field(): ) +def test_sync_feature_view_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.sync_feature_view in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.sync_feature_view + ] = mock_rpc + request = {} + client.sync_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.sync_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_sync_feature_view_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4966,6 +5970,52 @@ async def test_sync_feature_view_empty_call_async(): assert args[0] == feature_online_store_admin_service.SyncFeatureViewRequest() +@pytest.mark.asyncio +async def test_sync_feature_view_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.sync_feature_view + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.sync_feature_view + ] = mock_object + + request = {} + await client.sync_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.sync_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_sync_feature_view_async( transport: str = "grpc_asyncio", @@ -5211,6 +6261,9 @@ def test_get_feature_view_sync_empty_call(): with mock.patch.object( type(client.transport.get_feature_view_sync), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_feature_view_sync() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5236,6 +6289,9 @@ def test_get_feature_view_sync_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_feature_view_sync), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_feature_view_sync(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5244,6 +6300,46 @@ def test_get_feature_view_sync_non_empty_request_with_auto_populated_field(): ) +def test_get_feature_view_sync_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_feature_view_sync + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.get_feature_view_sync + ] = mock_rpc + request = {} + client.get_feature_view_sync(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_feature_view_sync(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_view_sync_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5269,6 +6365,52 @@ async def test_get_feature_view_sync_empty_call_async(): assert args[0] == feature_online_store_admin_service.GetFeatureViewSyncRequest() +@pytest.mark.asyncio +async def test_get_feature_view_sync_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_feature_view_sync + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_feature_view_sync + ] = mock_object + + request = {} + await client.get_feature_view_sync(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_feature_view_sync(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_view_sync_async( transport: str = "grpc_asyncio", @@ -5514,6 +6656,9 @@ def test_list_feature_view_syncs_empty_call(): with mock.patch.object( type(client.transport.list_feature_view_syncs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_feature_view_syncs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5544,6 +6689,9 @@ def test_list_feature_view_syncs_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_feature_view_syncs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_feature_view_syncs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5557,6 +6705,46 @@ def test_list_feature_view_syncs_non_empty_request_with_auto_populated_field(): ) +def test_list_feature_view_syncs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_feature_view_syncs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a 
string. + ) + client._transport._wrapped_methods[ + client._transport.list_feature_view_syncs + ] = mock_rpc + request = {} + client.list_feature_view_syncs(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_feature_view_syncs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_feature_view_syncs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5576,13 +6764,59 @@ async def test_list_feature_view_syncs_empty_call_async(): next_page_token="next_page_token_value", ) ) - response = await client.list_feature_view_syncs() - call.assert_called() - _, args, _ = call.mock_calls[0] + response = await client.list_feature_view_syncs() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert ( + args[0] == feature_online_store_admin_service.ListFeatureViewSyncsRequest() + ) + + +@pytest.mark.asyncio +async def test_list_feature_view_syncs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached assert ( - args[0] == feature_online_store_admin_service.ListFeatureViewSyncsRequest() + client._client._transport.list_feature_view_syncs + in client._client._transport._wrapped_methods ) + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return 
iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_feature_view_syncs + ] = mock_object + + request = {} + await client.list_feature_view_syncs(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.list_feature_view_syncs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + @pytest.mark.asyncio async def test_list_feature_view_syncs_async( @@ -6108,6 +7342,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_feature_online_store_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_feature_online_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_feature_online_store + ] = mock_rpc + + request = {} + client.create_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_feature_online_store_rest_required_fields( request_type=feature_online_store_admin_service.CreateFeatureOnlineStoreRequest, ): @@ -6426,6 +7705,47 @@ def test_get_feature_online_store_rest(request_type): assert response.state == feature_online_store.FeatureOnlineStore.State.STABLE +def test_get_feature_online_store_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_feature_online_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_feature_online_store + ] = mock_rpc + + request = {} + client.get_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_feature_online_store_rest_required_fields( request_type=feature_online_store_admin_service.GetFeatureOnlineStoreRequest, ): @@ -6706,6 +8026,47 @@ def test_list_feature_online_stores_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_feature_online_stores_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_feature_online_stores + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_feature_online_stores + ] = mock_rpc + + request = {} + client.list_feature_online_stores(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_feature_online_stores(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_feature_online_stores_rest_required_fields( request_type=feature_online_store_admin_service.ListFeatureOnlineStoresRequest, ): @@ -7164,6 +8525,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_feature_online_store_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_feature_online_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_feature_online_store + ] = mock_rpc + + request = {} + client.update_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_feature_online_store_rest_required_fields( request_type=feature_online_store_admin_service.UpdateFeatureOnlineStoreRequest, ): @@ -7450,6 +8856,51 @@ def test_delete_feature_online_store_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_feature_online_store_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_feature_online_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_feature_online_store + ] = mock_rpc + + request = {} + client.delete_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_feature_online_store_rest_required_fields( request_type=feature_online_store_admin_service.DeleteFeatureOnlineStoreRequest, ): @@ -7828,6 +9279,50 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_feature_view_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_feature_view in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_feature_view + ] = mock_rpc + + request = {} + client.create_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_feature_view_rest_required_fields( request_type=feature_online_store_admin_service.CreateFeatureViewRequest, ): @@ -8149,6 +9644,44 @@ def test_get_feature_view_rest(request_type): assert response.etag == "etag_value" +def test_get_feature_view_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_feature_view in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_feature_view + ] = mock_rpc + + request = {} + client.get_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_feature_view_rest_required_fields( request_type=feature_online_store_admin_service.GetFeatureViewRequest, ): @@ -8426,6 +9959,46 @@ def test_list_feature_views_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_feature_views_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_feature_views in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_feature_views + ] = mock_rpc + + request = {} + client.list_feature_views(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_feature_views(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_feature_views_rest_required_fields( request_type=feature_online_store_admin_service.ListFeatureViewsRequest, ): @@ -8892,6 +10465,50 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_feature_view_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_feature_view in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_feature_view + ] = mock_rpc + + request = {} + client.update_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_feature_view_rest_required_fields( request_type=feature_online_store_admin_service.UpdateFeatureViewRequest, ): @@ -9172,6 +10789,50 @@ def test_delete_feature_view_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_feature_view_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_feature_view in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_feature_view + ] = mock_rpc + + request = {} + client.delete_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_feature_view_rest_required_fields( request_type=feature_online_store_admin_service.DeleteFeatureViewRequest, ): @@ -9449,6 +11110,44 @@ def test_sync_feature_view_rest(request_type): assert response.feature_view_sync == "feature_view_sync_value" +def test_sync_feature_view_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.sync_feature_view in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.sync_feature_view + ] = mock_rpc + + request = {} + client.sync_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.sync_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_sync_feature_view_rest_required_fields( request_type=feature_online_store_admin_service.SyncFeatureViewRequest, ): @@ -9734,6 +11433,47 @@ def test_get_feature_view_sync_rest(request_type): assert response.name == "name_value" +def test_get_feature_view_sync_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_feature_view_sync + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_feature_view_sync + ] = mock_rpc + + request = {} + client.get_feature_view_sync(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_feature_view_sync(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_feature_view_sync_rest_required_fields( request_type=feature_online_store_admin_service.GetFeatureViewSyncRequest, ): @@ -10014,6 +11754,47 @@ def test_list_feature_view_syncs_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_feature_view_syncs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_feature_view_syncs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_feature_view_syncs + ] = mock_rpc + + request = {} + client.list_feature_view_syncs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_feature_view_syncs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_feature_view_syncs_rest_required_fields( request_type=feature_online_store_admin_service.ListFeatureViewSyncsRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_feature_online_store_service.py b/tests/unit/gapic/aiplatform_v1/test_feature_online_store_service.py index fec44c1515..39dff30efe 100644 --- a/tests/unit/gapic/aiplatform_v1/test_feature_online_store_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_feature_online_store_service.py @@ -1259,6 +1259,9 @@ def test_fetch_feature_values_empty_call(): with mock.patch.object( type(client.transport.fetch_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.fetch_feature_values() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1284,6 +1287,9 @@ def test_fetch_feature_values_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.fetch_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.fetch_feature_values(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1292,6 +1298,45 @@ def test_fetch_feature_values_non_empty_request_with_auto_populated_field(): ) +def test_fetch_feature_values_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.fetch_feature_values in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.fetch_feature_values + ] = mock_rpc + request = {} + client.fetch_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.fetch_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_fetch_feature_values_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1315,6 +1360,52 @@ async def test_fetch_feature_values_empty_call_async(): assert args[0] == feature_online_store_service.FetchFeatureValuesRequest() +@pytest.mark.asyncio +async def test_fetch_feature_values_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.fetch_feature_values + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.fetch_feature_values + ] = mock_object + + request = {} + await client.fetch_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.fetch_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_fetch_feature_values_async( transport: str = "grpc_asyncio", @@ -1564,6 +1655,9 @@ def test_search_nearest_entities_empty_call(): with mock.patch.object( type(client.transport.search_nearest_entities), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.search_nearest_entities() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1589,6 +1683,9 @@ def test_search_nearest_entities_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.search_nearest_entities), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.search_nearest_entities(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1597,6 +1694,46 @@ def test_search_nearest_entities_non_empty_request_with_auto_populated_field(): ) +def test_search_nearest_entities_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.search_nearest_entities + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a 
string. + ) + client._transport._wrapped_methods[ + client._transport.search_nearest_entities + ] = mock_rpc + request = {} + client.search_nearest_entities(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.search_nearest_entities(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_search_nearest_entities_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1620,6 +1757,52 @@ async def test_search_nearest_entities_empty_call_async(): assert args[0] == feature_online_store_service.SearchNearestEntitiesRequest() +@pytest.mark.asyncio +async def test_search_nearest_entities_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.search_nearest_entities + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.search_nearest_entities + ] = mock_object + + request = {} + await client.search_nearest_entities(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.search_nearest_entities(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_search_nearest_entities_async( transport: str = "grpc_asyncio", @@ -1767,6 +1950,46 @@ def test_fetch_feature_values_rest(request_type): assert isinstance(response, feature_online_store_service.FetchFeatureValuesResponse) +def test_fetch_feature_values_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.fetch_feature_values in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.fetch_feature_values + ] = mock_rpc + + request = {} + client.fetch_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.fetch_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_fetch_feature_values_rest_required_fields( request_type=feature_online_store_service.FetchFeatureValuesRequest, ): @@ -2051,6 +2274,47 @@ def test_search_nearest_entities_rest(request_type): ) +def test_search_nearest_entities_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.search_nearest_entities + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.search_nearest_entities + ] = mock_rpc + + request = {} + client.search_nearest_entities(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_nearest_entities(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_search_nearest_entities_rest_required_fields( request_type=feature_online_store_service.SearchNearestEntitiesRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_feature_registry_service.py b/tests/unit/gapic/aiplatform_v1/test_feature_registry_service.py index 3fe864f86c..1b868f7140 100644 --- a/tests/unit/gapic/aiplatform_v1/test_feature_registry_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_feature_registry_service.py @@ -1262,6 +1262,9 @@ def test_create_feature_group_empty_call(): with mock.patch.object( type(client.transport.create_feature_group), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_feature_group() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1288,6 +1291,9 @@ def test_create_feature_group_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_feature_group), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_feature_group(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1297,6 +1303,49 @@ def test_create_feature_group_non_empty_request_with_auto_populated_field(): ) +def test_create_feature_group_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_feature_group in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_feature_group + ] = mock_rpc + request = {} + client.create_feature_group(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_feature_group_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1320,6 +1369,56 @@ async def test_create_feature_group_empty_call_async(): assert args[0] == feature_registry_service.CreateFeatureGroupRequest() +@pytest.mark.asyncio +async def test_create_feature_group_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_feature_group + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_feature_group + ] = mock_object + + request = {} + await client.create_feature_group(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_feature_group_async( transport: str = "grpc_asyncio", @@ -1608,6 +1707,9 @@ def test_get_feature_group_empty_call(): with mock.patch.object( type(client.transport.get_feature_group), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_feature_group() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1633,6 +1735,9 @@ def test_get_feature_group_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_feature_group), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_feature_group(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1641,6 +1746,43 @@ def test_get_feature_group_non_empty_request_with_auto_populated_field(): ) +def test_get_feature_group_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_feature_group in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_feature_group + ] = mock_rpc + request = {} + client.get_feature_group(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_group_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1668,6 +1810,52 @@ async def test_get_feature_group_empty_call_async(): assert args[0] == feature_registry_service.GetFeatureGroupRequest() +@pytest.mark.asyncio +async def test_get_feature_group_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_feature_group + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_feature_group + ] = mock_object + + request = {} + await client.get_feature_group(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_group_async( transport: str = "grpc_asyncio", @@ -1915,6 +2103,9 @@ def test_list_feature_groups_empty_call(): with mock.patch.object( type(client.transport.list_feature_groups), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_feature_groups() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1943,6 +2134,9 @@ def test_list_feature_groups_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_feature_groups), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_feature_groups(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1954,6 +2148,45 @@ def test_list_feature_groups_non_empty_request_with_auto_populated_field(): ) +def test_list_feature_groups_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_feature_groups in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_feature_groups + ] = mock_rpc + request = {} + client.list_feature_groups(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_feature_groups(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_feature_groups_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1979,6 +2212,52 @@ async def test_list_feature_groups_empty_call_async(): assert args[0] == feature_registry_service.ListFeatureGroupsRequest() +@pytest.mark.asyncio +async def test_list_feature_groups_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_feature_groups + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_feature_groups + ] = mock_object + + request = {} + await client.list_feature_groups(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_feature_groups(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_feature_groups_async( transport: str = "grpc_asyncio", @@ -2417,6 +2696,9 @@ def test_update_feature_group_empty_call(): with mock.patch.object( type(client.transport.update_feature_group), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_feature_group() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2440,12 +2722,58 @@ def test_update_feature_group_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.update_feature_group), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_feature_group(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == feature_registry_service.UpdateFeatureGroupRequest() +def test_update_feature_group_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_feature_group in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.update_feature_group + ] = mock_rpc + request = {} + client.update_feature_group(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_feature_group_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2469,6 +2797,56 @@ async def test_update_feature_group_empty_call_async(): assert args[0] == feature_registry_service.UpdateFeatureGroupRequest() +@pytest.mark.asyncio +async def test_update_feature_group_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_feature_group + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_feature_group + ] = mock_object + + request = {} + await client.update_feature_group(request) + + # Establish that the underlying gRPC stub 
method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_feature_group_async( transport: str = "grpc_asyncio", @@ -2740,6 +3118,9 @@ def test_delete_feature_group_empty_call(): with mock.patch.object( type(client.transport.delete_feature_group), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_feature_group() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2765,6 +3146,9 @@ def test_delete_feature_group_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_feature_group), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_feature_group(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2773,6 +3157,49 @@ def test_delete_feature_group_non_empty_request_with_auto_populated_field(): ) +def test_delete_feature_group_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_feature_group in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_feature_group + ] = mock_rpc + request = {} + client.delete_feature_group(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_group_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2796,6 +3223,56 @@ async def test_delete_feature_group_empty_call_async(): assert args[0] == feature_registry_service.DeleteFeatureGroupRequest() +@pytest.mark.asyncio +async def test_delete_feature_group_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_feature_group + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_feature_group + ] = mock_object + + request = {} + await client.delete_feature_group(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_group_async( transport: str = "grpc_asyncio", @@ -3039,6 +3516,9 @@ def test_create_feature_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_feature() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3063,6 +3543,9 @@ def test_create_feature_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_feature(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3072,6 +3555,45 @@ def test_create_feature_non_empty_request_with_auto_populated_field(): ) +def test_create_feature_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_feature] = mock_rpc + request = {} + client.create_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_feature_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3093,6 +3615,56 @@ async def test_create_feature_empty_call_async(): assert args[0] == featurestore_service.CreateFeatureRequest() +@pytest.mark.asyncio +async def test_create_feature_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_feature + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_feature + ] = mock_object + + request = {} + await client.create_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_feature_async( transport: str = "grpc_asyncio", @@ -3351,6 +3923,9 @@ def test_get_feature_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_feature() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3374,6 +3949,9 @@ def test_get_feature_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_feature(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3382,6 +3960,41 @@ def test_get_feature_non_empty_request_with_auto_populated_field(): ) +def test_get_feature_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_feature] = mock_rpc + request = {} + client.get_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3411,6 +4024,52 @@ async def test_get_feature_empty_call_async(): assert args[0] == featurestore_service.GetFeatureRequest() +@pytest.mark.asyncio +async def test_get_feature_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_feature + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_feature + ] = mock_object + + request = {} + await client.get_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_async( transport: str = "grpc_asyncio", request_type=featurestore_service.GetFeatureRequest @@ -3647,6 +4306,9 @@ def test_list_features_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_features), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_features() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3673,6 +4335,9 @@ def test_list_features_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_features), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_features(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3684,6 +4349,41 @@ def test_list_features_non_empty_request_with_auto_populated_field(): ) +def test_list_features_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_features in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_features] = mock_rpc + request = {} + client.list_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_features_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3707,6 +4407,52 @@ async def test_list_features_empty_call_async(): assert args[0] == featurestore_service.ListFeaturesRequest() +@pytest.mark.asyncio +async def test_list_features_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_features + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_features + ] = mock_object + + request = {} + await client.list_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_features_async( transport: str = "grpc_asyncio", @@ -4123,6 +4869,9 @@ def test_update_feature_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_feature() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4144,12 +4893,54 @@ def test_update_feature_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_feature(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.UpdateFeatureRequest() +def test_update_feature_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[client._transport.update_feature] = mock_rpc + request = {} + client.update_feature(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_feature_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4171,6 +4962,56 @@ async def test_update_feature_empty_call_async(): assert args[0] == featurestore_service.UpdateFeatureRequest() +@pytest.mark.asyncio +async def test_update_feature_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_feature + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_feature + ] = mock_object + + request = {} + await client.update_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_feature_async( transport: str = "grpc_asyncio", @@ -4404,6 +5245,9 @@ def test_delete_feature_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_feature() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4427,6 +5271,9 @@ def test_delete_feature_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_feature(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4435,6 +5282,45 @@ def test_delete_feature_non_empty_request_with_auto_populated_field(): ) +def test_delete_feature_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_feature] = mock_rpc + request = {} + client.delete_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4456,6 +5342,56 @@ async def test_delete_feature_empty_call_async(): assert args[0] == featurestore_service.DeleteFeatureRequest() +@pytest.mark.asyncio +async def test_delete_feature_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_feature + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_feature + ] = mock_object + + request = {} + await client.delete_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_async( transport: str = "grpc_asyncio", @@ -4755,6 +5691,50 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_feature_group_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_feature_group in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_feature_group + ] = mock_rpc + + request = {} + client.create_feature_group(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_feature_group_rest_required_fields( request_type=feature_registry_service.CreateFeatureGroupRequest, ): @@ -5060,6 +6040,44 @@ def test_get_feature_group_rest(request_type): assert response.description == "description_value" +def test_get_feature_group_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_feature_group in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_feature_group + ] = mock_rpc + + request = {} + client.get_feature_group(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_feature_group_rest_required_fields( request_type=feature_registry_service.GetFeatureGroupRequest, ): @@ -5332,6 +6350,46 @@ def test_list_feature_groups_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_feature_groups_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_feature_groups in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_feature_groups + ] = mock_rpc + + request = {} + client.list_feature_groups(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_feature_groups(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_feature_groups_rest_required_fields( request_type=feature_registry_service.ListFeatureGroupsRequest, ): @@ -5766,6 +6824,50 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_feature_group_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_feature_group in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_feature_group + ] = mock_rpc + + request = {} + client.update_feature_group(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_feature_group_rest_required_fields( request_type=feature_registry_service.UpdateFeatureGroupRequest, ): @@ -6042,6 +7144,50 @@ def test_delete_feature_group_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_feature_group_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_feature_group in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_feature_group + ] = mock_rpc + + request = {} + client.delete_feature_group(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_feature_group_rest_required_fields( request_type=feature_registry_service.DeleteFeatureGroupRequest, ): @@ -6403,6 +7549,46 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_feature_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_feature] = mock_rpc + + request = {} + client.create_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_feature_rest_required_fields( request_type=featurestore_service.CreateFeatureRequest, ): @@ -6713,6 +7899,42 @@ def test_get_feature_rest(request_type): assert response.point_of_contact == "point_of_contact_value" +def test_get_feature_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_feature] = mock_rpc + + request = {} + client.get_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_feature_rest_required_fields( request_type=featurestore_service.GetFeatureRequest, ): @@ -6984,6 +8206,42 @@ def test_list_features_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_features_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_features in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_features] = mock_rpc + + request = {} + client.list_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_features_rest_required_fields( request_type=featurestore_service.ListFeaturesRequest, ): @@ -7429,6 +8687,46 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_feature_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_feature] = mock_rpc + + request = {} + client.update_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_feature_rest_required_fields( request_type=featurestore_service.UpdateFeatureRequest, ): @@ -7698,6 +8996,46 @@ def test_delete_feature_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_feature_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_feature] = mock_rpc + + request = {} + client.delete_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_feature_rest_required_fields( request_type=featurestore_service.DeleteFeatureRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_featurestore_online_serving_service.py b/tests/unit/gapic/aiplatform_v1/test_featurestore_online_serving_service.py index e3f5fa9e21..3b5e77c953 100644 --- a/tests/unit/gapic/aiplatform_v1/test_featurestore_online_serving_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_featurestore_online_serving_service.py @@ -1289,6 +1289,9 @@ def test_read_feature_values_empty_call(): with mock.patch.object( type(client.transport.read_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.read_feature_values() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1315,6 +1318,9 @@ def test_read_feature_values_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.read_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.read_feature_values(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1324,6 +1330,45 @@ def test_read_feature_values_non_empty_request_with_auto_populated_field(): ) +def test_read_feature_values_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreOnlineServingServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_feature_values in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_feature_values + ] = mock_rpc + request = {} + client.read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_read_feature_values_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1347,6 +1392,52 @@ async def test_read_feature_values_empty_call_async(): assert args[0] == featurestore_online_service.ReadFeatureValuesRequest() +@pytest.mark.asyncio +async def test_read_feature_values_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreOnlineServingServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.read_feature_values + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.read_feature_values + ] = mock_object + + request = {} + await client.read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.read_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_read_feature_values_async( transport: str = "grpc_asyncio", @@ -1589,6 +1680,9 @@ def test_streaming_read_feature_values_empty_call(): with mock.patch.object( type(client.transport.streaming_read_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.streaming_read_feature_values() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1616,6 +1710,9 @@ def test_streaming_read_feature_values_non_empty_request_with_auto_populated_fie with mock.patch.object( type(client.transport.streaming_read_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.streaming_read_feature_values(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1624,6 +1721,46 @@ def test_streaming_read_feature_values_non_empty_request_with_auto_populated_fie ) +def test_streaming_read_feature_values_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreOnlineServingServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.streaming_read_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.streaming_read_feature_values + ] = mock_rpc + request = {} + client.streaming_read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.streaming_read_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_streaming_read_feature_values_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1650,6 +1787,52 @@ async def test_streaming_read_feature_values_empty_call_async(): ) +@pytest.mark.asyncio +async def test_streaming_read_feature_values_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreOnlineServingServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.streaming_read_feature_values + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.streaming_read_feature_values + ] = mock_object + + request = {} + await client.streaming_read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.streaming_read_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_streaming_read_feature_values_async( transport: str = "grpc_asyncio", @@ -1894,6 +2077,9 @@ def test_write_feature_values_empty_call(): with mock.patch.object( type(client.transport.write_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.write_feature_values() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1919,6 +2105,9 @@ def test_write_feature_values_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.write_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.write_feature_values(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1927,6 +2116,45 @@ def test_write_feature_values_non_empty_request_with_auto_populated_field(): ) +def test_write_feature_values_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreOnlineServingServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.write_feature_values in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.write_feature_values + ] = mock_rpc + request = {} + client.write_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.write_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_write_feature_values_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1950,6 +2178,52 @@ async def test_write_feature_values_empty_call_async(): assert args[0] == featurestore_online_service.WriteFeatureValuesRequest() +@pytest.mark.asyncio +async def test_write_feature_values_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreOnlineServingServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.write_feature_values + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.write_feature_values + ] = mock_object + + request = {} + await client.write_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.write_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_write_feature_values_async( transport: str = "grpc_asyncio", @@ -2215,6 +2489,46 @@ def test_read_feature_values_rest(request_type): assert isinstance(response, featurestore_online_service.ReadFeatureValuesResponse) +def test_read_feature_values_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreOnlineServingServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_feature_values in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_feature_values + ] = mock_rpc + + request = {} + client.read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_read_feature_values_rest_required_fields( request_type=featurestore_online_service.ReadFeatureValuesRequest, ): @@ -2517,6 +2831,47 @@ def test_streaming_read_feature_values_rest(request_type): assert isinstance(response, featurestore_online_service.ReadFeatureValuesResponse) +def test_streaming_read_feature_values_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreOnlineServingServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.streaming_read_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.streaming_read_feature_values + ] = mock_rpc + + request = {} + client.streaming_read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.streaming_read_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_streaming_read_feature_values_rest_required_fields( request_type=featurestore_online_service.StreamingReadFeatureValuesRequest, ): @@ -2821,6 +3176,46 @@ def test_write_feature_values_rest(request_type): assert isinstance(response, featurestore_online_service.WriteFeatureValuesResponse) +def test_write_feature_values_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreOnlineServingServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.write_feature_values in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.write_feature_values + ] = mock_rpc + + request = {} + client.write_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.write_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_write_feature_values_rest_required_fields( request_type=featurestore_online_service.WriteFeatureValuesRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_featurestore_service.py b/tests/unit/gapic/aiplatform_v1/test_featurestore_service.py index 727370b1bb..28cef80a73 100644 --- a/tests/unit/gapic/aiplatform_v1/test_featurestore_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_featurestore_service.py @@ -1263,6 +1263,9 @@ def test_create_featurestore_empty_call(): with mock.patch.object( type(client.transport.create_featurestore), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_featurestore() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1289,6 +1292,9 @@ def test_create_featurestore_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_featurestore), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_featurestore(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1298,6 +1304,49 @@ def test_create_featurestore_non_empty_request_with_auto_populated_field(): ) +def test_create_featurestore_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_featurestore in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_featurestore + ] = mock_rpc + request = {} + client.create_featurestore(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_featurestore_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1321,6 +1370,56 @@ async def test_create_featurestore_empty_call_async(): assert args[0] == featurestore_service.CreateFeaturestoreRequest() +@pytest.mark.asyncio +async def test_create_featurestore_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_featurestore + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_featurestore + ] = mock_object + + request = {} + await client.create_featurestore(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_featurestore_async( transport: str = "grpc_asyncio", @@ -1583,6 +1682,9 @@ def test_get_featurestore_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_featurestore), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_featurestore() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1606,6 +1708,9 @@ def test_get_featurestore_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_featurestore), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_featurestore(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1614,6 +1719,43 @@ def test_get_featurestore_non_empty_request_with_auto_populated_field(): ) +def test_get_featurestore_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_featurestore in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_featurestore + ] = mock_rpc + request = {} + client.get_featurestore(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_featurestore_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1640,6 +1782,52 @@ async def test_get_featurestore_empty_call_async(): assert args[0] == featurestore_service.GetFeaturestoreRequest() +@pytest.mark.asyncio +async def test_get_featurestore_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_featurestore + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_featurestore + ] = mock_object + + request = {} + await client.get_featurestore(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_featurestore_async( transport: str = "grpc_asyncio", @@ -1879,6 +2067,9 @@ def test_list_featurestores_empty_call(): with mock.patch.object( type(client.transport.list_featurestores), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_featurestores() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1907,6 +2098,9 @@ def test_list_featurestores_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_featurestores), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_featurestores(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1918,6 +2112,45 @@ def test_list_featurestores_non_empty_request_with_auto_populated_field(): ) +def test_list_featurestores_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_featurestores in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_featurestores + ] = mock_rpc + request = {} + client.list_featurestores(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_featurestores(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_featurestores_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1943,6 +2176,52 @@ async def test_list_featurestores_empty_call_async(): assert args[0] == featurestore_service.ListFeaturestoresRequest() +@pytest.mark.asyncio +async def test_list_featurestores_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_featurestores + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_featurestores + ] = mock_object + + request = {} + await client.list_featurestores(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_featurestores(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_featurestores_async( transport: str = "grpc_asyncio", @@ -2381,6 +2660,9 @@ def test_update_featurestore_empty_call(): with mock.patch.object( type(client.transport.update_featurestore), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_featurestore() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2404,12 +2686,58 @@ def test_update_featurestore_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.update_featurestore), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_featurestore(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.UpdateFeaturestoreRequest() +def test_update_featurestore_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_featurestore in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.update_featurestore + ] = mock_rpc + request = {} + client.update_featurestore(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_featurestore_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2433,6 +2761,56 @@ async def test_update_featurestore_empty_call_async(): assert args[0] == featurestore_service.UpdateFeaturestoreRequest() +@pytest.mark.asyncio +async def test_update_featurestore_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_featurestore + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_featurestore + ] = mock_object + + request = {} + await client.update_featurestore(request) + + # Establish that the underlying gRPC stub method was 
called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_featurestore_async( transport: str = "grpc_asyncio", @@ -2680,6 +3058,9 @@ def test_delete_featurestore_empty_call(): with mock.patch.object( type(client.transport.delete_featurestore), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_featurestore() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2705,6 +3086,9 @@ def test_delete_featurestore_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_featurestore), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_featurestore(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2713,6 +3097,49 @@ def test_delete_featurestore_non_empty_request_with_auto_populated_field(): ) +def test_delete_featurestore_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_featurestore in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_featurestore + ] = mock_rpc + request = {} + client.delete_featurestore(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_featurestore_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2736,6 +3163,56 @@ async def test_delete_featurestore_empty_call_async(): assert args[0] == featurestore_service.DeleteFeaturestoreRequest() +@pytest.mark.asyncio +async def test_delete_featurestore_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_featurestore + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_featurestore + ] = mock_object + + request = {} + await client.delete_featurestore(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_featurestore_async( transport: str = "grpc_asyncio", @@ -2983,6 +3460,9 @@ def test_create_entity_type_empty_call(): with mock.patch.object( type(client.transport.create_entity_type), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_entity_type() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3009,6 +3489,9 @@ def test_create_entity_type_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_entity_type), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_entity_type(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3018,6 +3501,49 @@ def test_create_entity_type_non_empty_request_with_auto_populated_field(): ) +def test_create_entity_type_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_entity_type in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_entity_type + ] = mock_rpc + request = {} + client.create_entity_type(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_entity_type_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3041,6 +3567,56 @@ async def test_create_entity_type_empty_call_async(): assert args[0] == featurestore_service.CreateEntityTypeRequest() +@pytest.mark.asyncio +async def test_create_entity_type_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_entity_type + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_entity_type + ] = mock_object + + request = {} + await client.create_entity_type(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_entity_type_async( transport: str = "grpc_asyncio", @@ -3303,6 +3879,9 @@ def test_get_entity_type_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_entity_type), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_entity_type() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3326,6 +3905,9 @@ def test_get_entity_type_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_entity_type), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_entity_type(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3334,14 +3916,49 @@ def test_get_entity_type_non_empty_request_with_auto_populated_field(): ) -@pytest.mark.asyncio -async def test_get_entity_type_empty_call_async(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. 
- client = FeaturestoreServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc_asyncio", - ) +def test_get_entity_type_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_entity_type in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_entity_type] = mock_rpc + request = {} + client.get_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_entity_type_empty_call_async(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.get_entity_type), "__call__") as call: @@ -3360,6 +3977,52 @@ async def test_get_entity_type_empty_call_async(): assert args[0] == featurestore_service.GetEntityTypeRequest() +@pytest.mark.asyncio +async def test_get_entity_type_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_entity_type + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_entity_type + ] = mock_object + + request = {} + await client.get_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.get_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_entity_type_async( transport: str = "grpc_asyncio", @@ -3599,6 +4262,9 @@ def test_list_entity_types_empty_call(): with mock.patch.object( type(client.transport.list_entity_types), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_entity_types() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3627,6 +4293,9 @@ def test_list_entity_types_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_entity_types), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_entity_types(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3638,6 +4307,43 @@ def test_list_entity_types_non_empty_request_with_auto_populated_field(): ) +def test_list_entity_types_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_entity_types in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_entity_types + ] = mock_rpc + request = {} + client.list_entity_types(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_entity_types(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_entity_types_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3663,6 +4369,52 @@ async def test_list_entity_types_empty_call_async(): assert args[0] == featurestore_service.ListEntityTypesRequest() +@pytest.mark.asyncio +async def test_list_entity_types_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_entity_types + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_entity_types + ] = mock_object + + request = {} + await client.list_entity_types(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_entity_types(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_entity_types_async( transport: str = "grpc_asyncio", @@ -4110,6 +4862,9 @@ def test_update_entity_type_empty_call(): with mock.patch.object( type(client.transport.update_entity_type), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_entity_type() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4133,12 +4888,54 @@ def test_update_entity_type_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.update_entity_type), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_entity_type(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.UpdateEntityTypeRequest() +def test_update_entity_type_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_entity_type in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.update_entity_type + ] = mock_rpc + request = {} + client.update_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_entity_type_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4167,6 +4964,52 @@ async def test_update_entity_type_empty_call_async(): assert args[0] == featurestore_service.UpdateEntityTypeRequest() +@pytest.mark.asyncio +async def test_update_entity_type_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_entity_type + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_entity_type + ] = mock_object + + request = {} + await client.update_entity_type(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_entity_type_async( transport: str = "grpc_asyncio", @@ -4423,6 +5266,9 @@ def test_delete_entity_type_empty_call(): with mock.patch.object( type(client.transport.delete_entity_type), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_entity_type() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4448,6 +5294,9 @@ def test_delete_entity_type_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_entity_type), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_entity_type(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4456,6 +5305,49 @@ def test_delete_entity_type_non_empty_request_with_auto_populated_field(): ) +def test_delete_entity_type_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_entity_type in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_entity_type + ] = mock_rpc + request = {} + client.delete_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_entity_type_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4479,6 +5371,56 @@ async def test_delete_entity_type_empty_call_async(): assert args[0] == featurestore_service.DeleteEntityTypeRequest() +@pytest.mark.asyncio +async def test_delete_entity_type_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_entity_type + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_entity_type + ] = mock_object + + request = {} + await client.delete_entity_type(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_entity_type_async( transport: str = "grpc_asyncio", @@ -4722,6 +5664,9 @@ def test_create_feature_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_feature() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4746,6 +5691,9 @@ def test_create_feature_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_feature(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4755,6 +5703,45 @@ def test_create_feature_non_empty_request_with_auto_populated_field(): ) +def test_create_feature_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_feature] = mock_rpc + request = {} + client.create_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_feature_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4776,6 +5763,56 @@ async def test_create_feature_empty_call_async(): assert args[0] == featurestore_service.CreateFeatureRequest() +@pytest.mark.asyncio +async def test_create_feature_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_feature + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_feature + ] = mock_object + + request = {} + await client.create_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_feature_async( transport: str = "grpc_asyncio", @@ -5023,6 +6060,9 @@ def test_batch_create_features_empty_call(): with mock.patch.object( type(client.transport.batch_create_features), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_create_features() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5048,6 +6088,9 @@ def test_batch_create_features_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.batch_create_features), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.batch_create_features(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5056,6 +6099,50 @@ def test_batch_create_features_non_empty_request_with_auto_populated_field(): ) +def test_batch_create_features_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_create_features + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_create_features + ] = mock_rpc + request = {} + client.batch_create_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.batch_create_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_batch_create_features_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5079,6 +6166,56 @@ async def test_batch_create_features_empty_call_async(): assert args[0] == featurestore_service.BatchCreateFeaturesRequest() +@pytest.mark.asyncio +async def test_batch_create_features_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_create_features + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_create_features + ] = mock_object + + request = {} + await client.batch_create_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.batch_create_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_batch_create_features_async( transport: str = "grpc_asyncio", @@ -5337,6 +6474,9 @@ def test_get_feature_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_feature() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5360,6 +6500,9 @@ def test_get_feature_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_feature(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5368,6 +6511,41 @@ def test_get_feature_non_empty_request_with_auto_populated_field(): ) +def test_get_feature_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_feature] = mock_rpc + request = {} + client.get_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5397,6 +6575,52 @@ async def test_get_feature_empty_call_async(): assert args[0] == featurestore_service.GetFeatureRequest() +@pytest.mark.asyncio +async def test_get_feature_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_feature + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_feature + ] = mock_object + + request = {} + await client.get_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_async( transport: str = "grpc_asyncio", request_type=featurestore_service.GetFeatureRequest @@ -5633,6 +6857,9 @@ def test_list_features_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_features), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_features() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5659,6 +6886,9 @@ def test_list_features_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_features), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_features(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5670,6 +6900,41 @@ def test_list_features_non_empty_request_with_auto_populated_field(): ) +def test_list_features_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_features in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_features] = mock_rpc + request = {} + client.list_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_features_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5693,6 +6958,52 @@ async def test_list_features_empty_call_async(): assert args[0] == featurestore_service.ListFeaturesRequest() +@pytest.mark.asyncio +async def test_list_features_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_features + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_features + ] = mock_object + + request = {} + await client.list_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_features_async( transport: str = "grpc_asyncio", @@ -6124,6 +7435,9 @@ def test_update_feature_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_feature() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6145,12 +7459,50 @@ def test_update_feature_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_feature(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.UpdateFeatureRequest() +def test_update_feature_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[client._transport.update_feature] = mock_rpc + request = {} + client.update_feature(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_feature_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6180,6 +7532,52 @@ async def test_update_feature_empty_call_async(): assert args[0] == featurestore_service.UpdateFeatureRequest() +@pytest.mark.asyncio +async def test_update_feature_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_feature + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_feature + ] = mock_object + + request = {} + await client.update_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_feature_async( transport: str = "grpc_asyncio", @@ -6424,6 +7822,9 @@ def test_delete_feature_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_feature() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6447,6 +7848,9 @@ def test_delete_feature_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_feature(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6455,6 +7859,45 @@ def test_delete_feature_non_empty_request_with_auto_populated_field(): ) +def test_delete_feature_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_feature] = mock_rpc + request = {} + client.delete_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6476,6 +7919,56 @@ async def test_delete_feature_empty_call_async(): assert args[0] == featurestore_service.DeleteFeatureRequest() +@pytest.mark.asyncio +async def test_delete_feature_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_feature + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_feature + ] = mock_object + + request = {} + await client.delete_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_async( transport: str = "grpc_asyncio", @@ -6703,6 +8196,9 @@ def test_import_feature_values_empty_call(): with mock.patch.object( type(client.transport.import_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.import_feature_values() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6730,6 +8226,9 @@ def test_import_feature_values_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.import_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.import_feature_values(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6740,6 +8239,50 @@ def test_import_feature_values_non_empty_request_with_auto_populated_field(): ) +def test_import_feature_values_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.import_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.import_feature_values + ] = mock_rpc + request = {} + client.import_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.import_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_import_feature_values_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6763,6 +8306,56 @@ async def test_import_feature_values_empty_call_async(): assert args[0] == featurestore_service.ImportFeatureValuesRequest() +@pytest.mark.asyncio +async def test_import_feature_values_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.import_feature_values + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.import_feature_values + ] = mock_object + + request = {} + await client.import_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.import_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_import_feature_values_async( transport: str = "grpc_asyncio", @@ -7000,6 +8593,9 @@ def test_batch_read_feature_values_empty_call(): with mock.patch.object( type(client.transport.batch_read_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_read_feature_values() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7025,6 +8621,9 @@ def test_batch_read_feature_values_non_empty_request_with_auto_populated_field() with mock.patch.object( type(client.transport.batch_read_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.batch_read_feature_values(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7033,6 +8632,50 @@ def test_batch_read_feature_values_non_empty_request_with_auto_populated_field() ) +def test_batch_read_feature_values_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_read_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_read_feature_values + ] = mock_rpc + request = {} + client.batch_read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.batch_read_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_batch_read_feature_values_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7056,6 +8699,56 @@ async def test_batch_read_feature_values_empty_call_async(): assert args[0] == featurestore_service.BatchReadFeatureValuesRequest() +@pytest.mark.asyncio +async def test_batch_read_feature_values_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_read_feature_values + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_read_feature_values + ] = mock_object + + request = {} + await client.batch_read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.batch_read_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_batch_read_feature_values_async( transport: str = "grpc_asyncio", @@ -7293,6 +8986,9 @@ def test_export_feature_values_empty_call(): with mock.patch.object( type(client.transport.export_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.export_feature_values() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7318,6 +9014,9 @@ def test_export_feature_values_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.export_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.export_feature_values(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7326,6 +9025,50 @@ def test_export_feature_values_non_empty_request_with_auto_populated_field(): ) +def test_export_feature_values_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.export_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.export_feature_values + ] = mock_rpc + request = {} + client.export_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.export_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_export_feature_values_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7349,6 +9092,56 @@ async def test_export_feature_values_empty_call_async(): assert args[0] == featurestore_service.ExportFeatureValuesRequest() +@pytest.mark.asyncio +async def test_export_feature_values_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.export_feature_values + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.export_feature_values + ] = mock_object + + request = {} + await client.export_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.export_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_export_feature_values_async( transport: str = "grpc_asyncio", @@ -7586,6 +9379,9 @@ def test_delete_feature_values_empty_call(): with mock.patch.object( type(client.transport.delete_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_feature_values() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7611,6 +9407,9 @@ def test_delete_feature_values_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_feature_values(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7619,6 +9418,50 @@ def test_delete_feature_values_non_empty_request_with_auto_populated_field(): ) +def test_delete_feature_values_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_feature_values + ] = mock_rpc + request = {} + client.delete_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_values_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7642,6 +9485,56 @@ async def test_delete_feature_values_empty_call_async(): assert args[0] == featurestore_service.DeleteFeatureValuesRequest() +@pytest.mark.asyncio +async def test_delete_feature_values_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_feature_values + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_feature_values + ] = mock_object + + request = {} + await client.delete_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_values_async( transport: str = "grpc_asyncio", @@ -7878,6 +9771,9 @@ def test_search_features_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.search_features), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.search_features() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7903,6 +9799,9 @@ def test_search_features_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.search_features), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.search_features(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7913,6 +9812,41 @@ def test_search_features_non_empty_request_with_auto_populated_field(): ) +def test_search_features_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.search_features in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.search_features] = mock_rpc + request = {} + client.search_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_search_features_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7936,6 +9870,52 @@ async def test_search_features_empty_call_async(): assert args[0] == featurestore_service.SearchFeaturesRequest() +@pytest.mark.asyncio +async def test_search_features_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.search_features + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.search_features + ] = mock_object + + request = {} + await client.search_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.search_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_search_features_async( transport: str = "grpc_asyncio", @@ -8441,6 +10421,50 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_featurestore_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_featurestore in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_featurestore + ] = mock_rpc + + request = {} + client.create_featurestore(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_featurestore_rest_required_fields( request_type=featurestore_service.CreateFeaturestoreRequest, ): @@ -8739,6 +10763,44 @@ def test_get_featurestore_rest(request_type): assert response.online_storage_ttl_days == 2460 +def test_get_featurestore_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_featurestore in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_featurestore + ] = mock_rpc + + request = {} + client.get_featurestore(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_featurestore_rest_required_fields( request_type=featurestore_service.GetFeaturestoreRequest, ): @@ -9008,6 +11070,46 @@ def test_list_featurestores_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_featurestores_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_featurestores in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_featurestores + ] = mock_rpc + + request = {} + client.list_featurestores(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_featurestores(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_featurestores_rest_required_fields( request_type=featurestore_service.ListFeaturestoresRequest, ): @@ -9443,6 +11545,50 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_featurestore_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_featurestore in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_featurestore + ] = mock_rpc + + request = {} + client.update_featurestore(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_featurestore_rest_required_fields( request_type=featurestore_service.UpdateFeaturestoreRequest, ): @@ -9710,6 +11856,50 @@ def test_delete_featurestore_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_featurestore_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_featurestore in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_featurestore + ] = mock_rpc + + request = {} + client.delete_featurestore(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_featurestore_rest_required_fields( request_type=featurestore_service.DeleteFeaturestoreRequest, ): @@ -10063,6 +12253,50 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_entity_type_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_entity_type in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_entity_type + ] = mock_rpc + + request = {} + client.create_entity_type(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_entity_type_rest_required_fields( request_type=featurestore_service.CreateEntityTypeRequest, ): @@ -10354,16 +12588,52 @@ def test_get_entity_type_rest(request_type): return_value = entity_type.EntityType.pb(return_value) json_return_value = json_format.MessageToJson(return_value) - response_value._content = json_return_value.encode("UTF-8") - req.return_value = response_value - response = client.get_entity_type(request) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.get_entity_type(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, entity_type.EntityType) + assert response.name == "name_value" + assert response.description == "description_value" + assert response.etag == "etag_value" + assert response.offline_storage_ttl_days == 2554 + + +def test_get_entity_type_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_entity_type in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute 
client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_entity_type] = mock_rpc + + request = {} + client.get_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_entity_type(request) - # Establish that the response is the type that we expect. - assert isinstance(response, entity_type.EntityType) - assert response.name == "name_value" - assert response.description == "description_value" - assert response.etag == "etag_value" - assert response.offline_storage_ttl_days == 2554 + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 def test_get_entity_type_rest_required_fields( @@ -10639,6 +12909,44 @@ def test_list_entity_types_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_entity_types_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_entity_types in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_entity_types + ] = mock_rpc + + request = {} + client.list_entity_types(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_entity_types(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_entity_types_rest_required_fields( request_type=featurestore_service.ListEntityTypesRequest, ): @@ -11088,6 +13396,46 @@ def get_message_fields(field): assert response.offline_storage_ttl_days == 2554 +def test_update_entity_type_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_entity_type in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_entity_type + ] = mock_rpc + + request = {} + client.update_entity_type(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_entity_type_rest_required_fields( request_type=featurestore_service.UpdateEntityTypeRequest, ): @@ -11360,6 +13708,50 @@ def test_delete_entity_type_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_entity_type_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_entity_type in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_entity_type + ] = mock_rpc + + request = {} + client.delete_entity_type(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_entity_type_rest_required_fields( request_type=featurestore_service.DeleteEntityTypeRequest, ): @@ -11722,6 +14114,46 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_feature_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_feature] = mock_rpc + + request = {} + client.create_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_feature_rest_required_fields( request_type=featurestore_service.CreateFeatureRequest, ): @@ -12015,6 +14447,51 @@ def test_batch_create_features_rest(request_type): assert response.operation.name == "operations/spam" +def test_batch_create_features_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_create_features + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_create_features + ] = mock_rpc + + request = {} + client.batch_create_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.batch_create_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_batch_create_features_rest_required_fields( request_type=featurestore_service.BatchCreateFeaturesRequest, ): @@ -12309,6 +14786,42 @@ def test_get_feature_rest(request_type): assert response.point_of_contact == "point_of_contact_value" +def test_get_feature_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_feature] = mock_rpc + + request = {} + client.get_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_feature_rest_required_fields( request_type=featurestore_service.GetFeatureRequest, ): @@ -12580,6 +15093,42 @@ def test_list_features_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_features_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_features in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_features] = mock_rpc + + request = {} + client.list_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_features_rest_required_fields( request_type=featurestore_service.ListFeaturesRequest, ): @@ -13042,6 +15591,42 @@ def get_message_fields(field): assert response.point_of_contact == "point_of_contact_value" +def test_update_feature_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_feature] = mock_rpc + + request = {} + client.update_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_feature_rest_required_fields( request_type=featurestore_service.UpdateFeatureRequest, ): @@ -13312,6 +15897,46 @@ def test_delete_feature_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_feature_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_feature] = mock_rpc + + request = {} + client.delete_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_feature_rest_required_fields( request_type=featurestore_service.DeleteFeatureRequest, ): @@ -13577,6 +16202,51 @@ def test_import_feature_values_rest(request_type): assert response.operation.name == "operations/spam" +def test_import_feature_values_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.import_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.import_feature_values + ] = mock_rpc + + request = {} + client.import_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.import_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_import_feature_values_rest_required_fields( request_type=featurestore_service.ImportFeatureValuesRequest, ): @@ -13852,6 +16522,51 @@ def test_batch_read_feature_values_rest(request_type): assert response.operation.name == "operations/spam" +def test_batch_read_feature_values_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_read_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_read_feature_values + ] = mock_rpc + + request = {} + client.batch_read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.batch_read_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_batch_read_feature_values_rest_required_fields( request_type=featurestore_service.BatchReadFeatureValuesRequest, ): @@ -14128,6 +16843,51 @@ def test_export_feature_values_rest(request_type): assert response.operation.name == "operations/spam" +def test_export_feature_values_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.export_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.export_feature_values + ] = mock_rpc + + request = {} + client.export_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.export_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_export_feature_values_rest_required_fields( request_type=featurestore_service.ExportFeatureValuesRequest, ): @@ -14404,6 +17164,51 @@ def test_delete_feature_values_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_feature_values_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_feature_values + ] = mock_rpc + + request = {} + client.delete_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_feature_values_rest_required_fields( request_type=featurestore_service.DeleteFeatureValuesRequest, ): @@ -14674,6 +17479,42 @@ def test_search_features_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_search_features_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.search_features in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.search_features] = mock_rpc + + request = {} + client.search_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_search_features_rest_required_fields( request_type=featurestore_service.SearchFeaturesRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_gen_ai_tuning_service.py b/tests/unit/gapic/aiplatform_v1/test_gen_ai_tuning_service.py index e1bb44ac27..cbac28c368 100644 --- a/tests/unit/gapic/aiplatform_v1/test_gen_ai_tuning_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_gen_ai_tuning_service.py @@ -1239,6 +1239,9 @@ def test_create_tuning_job_empty_call(): with mock.patch.object( type(client.transport.create_tuning_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1264,6 +1267,9 @@ def test_create_tuning_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_tuning_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_tuning_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1272,6 +1278,43 @@ def test_create_tuning_job_non_empty_request_with_auto_populated_field(): ) +def test_create_tuning_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenAiTuningServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_tuning_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_tuning_job + ] = mock_rpc + request = {} + client.create_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_tuning_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1301,6 +1344,52 @@ async def test_create_tuning_job_empty_call_async(): assert args[0] == genai_tuning_service.CreateTuningJobRequest() +@pytest.mark.asyncio +async def test_create_tuning_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = GenAiTuningServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_tuning_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_tuning_job + ] = mock_object + + request = {} + await client.create_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_tuning_job_async( transport: str = "grpc_asyncio", @@ -1567,6 +1656,9 @@ def test_get_tuning_job_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_tuning_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1590,6 +1682,9 @@ def test_get_tuning_job_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_tuning_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_tuning_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1598,6 +1693,41 @@ def test_get_tuning_job_non_empty_request_with_auto_populated_field(): ) +def test_get_tuning_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenAiTuningServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_tuning_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_tuning_job] = mock_rpc + request = {} + client.get_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_tuning_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1625,6 +1755,52 @@ async def test_get_tuning_job_empty_call_async(): assert args[0] == genai_tuning_service.GetTuningJobRequest() +@pytest.mark.asyncio +async def test_get_tuning_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = GenAiTuningServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_tuning_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_tuning_job + ] = mock_object + + request = {} + await client.get_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_tuning_job_async( transport: str = "grpc_asyncio", @@ -1862,6 +2038,9 @@ def test_list_tuning_jobs_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_tuning_jobs), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_tuning_jobs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1887,6 +2066,9 @@ def test_list_tuning_jobs_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_tuning_jobs), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_tuning_jobs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1897,6 +2079,43 @@ def test_list_tuning_jobs_non_empty_request_with_auto_populated_field(): ) +def test_list_tuning_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenAiTuningServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_tuning_jobs in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_tuning_jobs + ] = mock_rpc + request = {} + client.list_tuning_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_tuning_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_tuning_jobs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1920,6 +2139,52 @@ async def test_list_tuning_jobs_empty_call_async(): assert args[0] == genai_tuning_service.ListTuningJobsRequest() +@pytest.mark.asyncio +async def test_list_tuning_jobs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = GenAiTuningServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_tuning_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_tuning_jobs + ] = mock_object + + request = {} + await client.list_tuning_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_tuning_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_tuning_jobs_async( transport: str = "grpc_asyncio", @@ -2340,6 +2605,9 @@ def test_cancel_tuning_job_empty_call(): with mock.patch.object( type(client.transport.cancel_tuning_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.cancel_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2365,6 +2633,9 @@ def test_cancel_tuning_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.cancel_tuning_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.cancel_tuning_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2373,6 +2644,43 @@ def test_cancel_tuning_job_non_empty_request_with_auto_populated_field(): ) +def test_cancel_tuning_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenAiTuningServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.cancel_tuning_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.cancel_tuning_job + ] = mock_rpc + request = {} + client.cancel_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.cancel_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_tuning_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2394,6 +2702,52 @@ async def test_cancel_tuning_job_empty_call_async(): assert args[0] == genai_tuning_service.CancelTuningJobRequest() +@pytest.mark.asyncio +async def test_cancel_tuning_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = GenAiTuningServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.cancel_tuning_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.cancel_tuning_job + ] = mock_object + + request = {} + await client.cancel_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.cancel_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_tuning_job_async( transport: str = "grpc_asyncio", @@ -2777,6 +3131,44 @@ def get_message_fields(field): assert response.experiment == "experiment_value" +def test_create_tuning_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenAiTuningServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_tuning_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_tuning_job + ] = mock_rpc + + request = {} + client.create_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_tuning_job_rest_required_fields( request_type=genai_tuning_service.CreateTuningJobRequest, ): @@ -3063,6 +3455,42 @@ def test_get_tuning_job_rest(request_type): assert response.experiment == "experiment_value" +def test_get_tuning_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenAiTuningServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_tuning_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_tuning_job] = mock_rpc + + request = {} + client.get_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_tuning_job_rest_required_fields( request_type=genai_tuning_service.GetTuningJobRequest, ): @@ -3329,6 +3757,44 @@ def test_list_tuning_jobs_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_tuning_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenAiTuningServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_tuning_jobs in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_tuning_jobs + ] = mock_rpc + + request = {} + client.list_tuning_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_tuning_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_tuning_jobs_rest_required_fields( request_type=genai_tuning_service.ListTuningJobsRequest, ): @@ -3664,6 +4130,44 @@ def test_cancel_tuning_job_rest(request_type): assert response is None +def test_cancel_tuning_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenAiTuningServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.cancel_tuning_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_tuning_job + ] = mock_rpc + + request = {} + client.cancel_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_cancel_tuning_job_rest_required_fields( request_type=genai_tuning_service.CancelTuningJobRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_index_endpoint_service.py b/tests/unit/gapic/aiplatform_v1/test_index_endpoint_service.py index 54f585baa4..22e8bf60a6 100644 --- a/tests/unit/gapic/aiplatform_v1/test_index_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_index_endpoint_service.py @@ -1257,6 +1257,9 @@ def test_create_index_endpoint_empty_call(): with mock.patch.object( type(client.transport.create_index_endpoint), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_index_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1282,6 +1285,9 @@ def test_create_index_endpoint_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_index_endpoint), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_index_endpoint(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1290,6 +1296,50 @@ def test_create_index_endpoint_non_empty_request_with_auto_populated_field(): ) +def test_create_index_endpoint_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_index_endpoint + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_index_endpoint + ] = mock_rpc + request = {} + client.create_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_index_endpoint_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1313,6 +1363,56 @@ async def test_create_index_endpoint_empty_call_async(): assert args[0] == index_endpoint_service.CreateIndexEndpointRequest() +@pytest.mark.asyncio +async def test_create_index_endpoint_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexEndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_index_endpoint + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_index_endpoint + ] = mock_object + + request = {} + await client.create_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_index_endpoint_async( transport: str = "grpc_asyncio", @@ -1577,6 +1677,9 @@ def test_get_index_endpoint_empty_call(): with mock.patch.object( type(client.transport.get_index_endpoint), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_index_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1602,6 +1705,9 @@ def test_get_index_endpoint_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_index_endpoint), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_index_endpoint(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1610,6 +1716,45 @@ def test_get_index_endpoint_non_empty_request_with_auto_populated_field(): ) +def test_get_index_endpoint_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_index_endpoint in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_index_endpoint + ] = mock_rpc + request = {} + client.get_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_index_endpoint_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1642,6 +1787,52 @@ async def test_get_index_endpoint_empty_call_async(): assert args[0] == index_endpoint_service.GetIndexEndpointRequest() +@pytest.mark.asyncio +async def test_get_index_endpoint_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexEndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_index_endpoint + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_index_endpoint + ] = mock_object + + request = {} + await client.get_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_index_endpoint_async( transport: str = "grpc_asyncio", @@ -1899,6 +2090,9 @@ def test_list_index_endpoints_empty_call(): with mock.patch.object( type(client.transport.list_index_endpoints), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_index_endpoints() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1926,6 +2120,9 @@ def test_list_index_endpoints_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_index_endpoints), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_index_endpoints(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1936,6 +2133,45 @@ def test_list_index_endpoints_non_empty_request_with_auto_populated_field(): ) +def test_list_index_endpoints_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_index_endpoints in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_index_endpoints + ] = mock_rpc + request = {} + client.list_index_endpoints(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_index_endpoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_index_endpoints_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1961,6 +2197,52 @@ async def test_list_index_endpoints_empty_call_async(): assert args[0] == index_endpoint_service.ListIndexEndpointsRequest() +@pytest.mark.asyncio +async def test_list_index_endpoints_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexEndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_index_endpoints + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_index_endpoints + ] = mock_object + + request = {} + await client.list_index_endpoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_index_endpoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_index_endpoints_async( transport: str = "grpc_asyncio", @@ -2416,6 +2698,9 @@ def test_update_index_endpoint_empty_call(): with mock.patch.object( type(client.transport.update_index_endpoint), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_index_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2439,12 +2724,55 @@ def test_update_index_endpoint_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.update_index_endpoint), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_index_endpoint(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == index_endpoint_service.UpdateIndexEndpointRequest() +def test_update_index_endpoint_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_index_endpoint + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.update_index_endpoint + ] = mock_rpc + request = {} + client.update_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_index_endpoint_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2477,6 +2805,52 @@ async def test_update_index_endpoint_empty_call_async(): assert args[0] == index_endpoint_service.UpdateIndexEndpointRequest() +@pytest.mark.asyncio +async def test_update_index_endpoint_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexEndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_index_endpoint + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_index_endpoint + ] = mock_object + + request = {} + await client.update_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_index_endpoint_async( transport: str = "grpc_asyncio", @@ -2741,6 +3115,9 @@ def test_delete_index_endpoint_empty_call(): with mock.patch.object( type(client.transport.delete_index_endpoint), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_index_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2766,6 +3143,9 @@ def test_delete_index_endpoint_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_index_endpoint), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_index_endpoint(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2774,6 +3154,50 @@ def test_delete_index_endpoint_non_empty_request_with_auto_populated_field(): ) +def test_delete_index_endpoint_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_index_endpoint + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_index_endpoint + ] = mock_rpc + request = {} + client.delete_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_index_endpoint_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2797,6 +3221,56 @@ async def test_delete_index_endpoint_empty_call_async(): assert args[0] == index_endpoint_service.DeleteIndexEndpointRequest() +@pytest.mark.asyncio +async def test_delete_index_endpoint_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexEndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_index_endpoint + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_index_endpoint + ] = mock_object + + request = {} + await client.delete_index_endpoint(request) + + # Establish that the underlying 
gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_index_endpoint_async( transport: str = "grpc_asyncio", @@ -3030,6 +3504,9 @@ def test_deploy_index_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.deploy_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.deploy_index() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3053,6 +3530,9 @@ def test_deploy_index_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.deploy_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.deploy_index(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3061,6 +3541,45 @@ def test_deploy_index_non_empty_request_with_auto_populated_field(): ) +def test_deploy_index_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.deploy_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.deploy_index] = mock_rpc + request = {} + client.deploy_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.deploy_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_deploy_index_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3082,6 +3601,56 @@ async def test_deploy_index_empty_call_async(): assert args[0] == index_endpoint_service.DeployIndexRequest() +@pytest.mark.asyncio +async def test_deploy_index_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexEndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.deploy_index + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.deploy_index + ] = mock_object + + request = {} + await client.deploy_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.deploy_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_deploy_index_async( transport: str = "grpc_asyncio", @@ -3315,6 +3884,9 @@ def test_undeploy_index_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.undeploy_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.undeploy_index() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3339,6 +3911,9 @@ def test_undeploy_index_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.undeploy_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.undeploy_index(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3348,6 +3923,45 @@ def test_undeploy_index_non_empty_request_with_auto_populated_field(): ) +def test_undeploy_index_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.undeploy_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.undeploy_index] = mock_rpc + request = {} + client.undeploy_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.undeploy_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_undeploy_index_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3369,6 +3983,56 @@ async def test_undeploy_index_empty_call_async(): assert args[0] == index_endpoint_service.UndeployIndexRequest() +@pytest.mark.asyncio +async def test_undeploy_index_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexEndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.undeploy_index + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.undeploy_index + ] = mock_object + + request = {} + await client.undeploy_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.undeploy_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_undeploy_index_async( transport: str = "grpc_asyncio", @@ -3606,6 +4270,9 @@ def test_mutate_deployed_index_empty_call(): with mock.patch.object( type(client.transport.mutate_deployed_index), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.mutate_deployed_index() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3631,6 +4298,9 @@ def test_mutate_deployed_index_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.mutate_deployed_index), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.mutate_deployed_index(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3639,6 +4309,50 @@ def test_mutate_deployed_index_non_empty_request_with_auto_populated_field(): ) +def test_mutate_deployed_index_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.mutate_deployed_index + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.mutate_deployed_index + ] = mock_rpc + request = {} + client.mutate_deployed_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.mutate_deployed_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_mutate_deployed_index_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3662,6 +4376,56 @@ async def test_mutate_deployed_index_empty_call_async(): assert args[0] == index_endpoint_service.MutateDeployedIndexRequest() +@pytest.mark.asyncio +async def test_mutate_deployed_index_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexEndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.mutate_deployed_index + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.mutate_deployed_index + ] = mock_object + + request = {} + await client.mutate_deployed_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.mutate_deployed_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_mutate_deployed_index_async( transport: str = "grpc_asyncio", @@ -4039,6 +4803,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_index_endpoint_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_index_endpoint + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_index_endpoint + ] = mock_rpc + + request = {} + client.create_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_index_endpoint_rest_required_fields( request_type=index_endpoint_service.CreateIndexEndpointRequest, ): @@ -4329,6 +5138,46 @@ def test_get_index_endpoint_rest(request_type): assert response.public_endpoint_domain_name == "public_endpoint_domain_name_value" +def test_get_index_endpoint_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_index_endpoint in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_index_endpoint + ] = mock_rpc + + request = {} + client.get_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_index_endpoint_rest_required_fields( request_type=index_endpoint_service.GetIndexEndpointRequest, ): @@ -4600,6 +5449,46 @@ def test_list_index_endpoints_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_index_endpoints_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_index_endpoints in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_index_endpoints + ] = mock_rpc + + request = {} + client.list_index_endpoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_index_endpoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_index_endpoints_rest_required_fields( request_type=index_endpoint_service.ListIndexEndpointsRequest, ): @@ -5111,6 +6000,47 @@ def get_message_fields(field): assert response.public_endpoint_domain_name == "public_endpoint_domain_name_value" +def test_update_index_endpoint_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_index_endpoint + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_index_endpoint + ] = mock_rpc + + request = {} + client.update_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_index_endpoint_rest_required_fields( request_type=index_endpoint_service.UpdateIndexEndpointRequest, ): @@ -5390,6 +6320,51 @@ def test_delete_index_endpoint_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_index_endpoint_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_index_endpoint + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_index_endpoint + ] = mock_rpc + + request = {} + client.delete_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_index_endpoint_rest_required_fields( request_type=index_endpoint_service.DeleteIndexEndpointRequest, ): @@ -5654,6 +6629,46 @@ def test_deploy_index_rest(request_type): assert response.operation.name == "operations/spam" +def test_deploy_index_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.deploy_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.deploy_index] = mock_rpc + + request = {} + client.deploy_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.deploy_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_deploy_index_rest_required_fields( request_type=index_endpoint_service.DeployIndexRequest, ): @@ -5930,6 +6945,46 @@ def test_undeploy_index_rest(request_type): assert response.operation.name == "operations/spam" +def test_undeploy_index_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.undeploy_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.undeploy_index] = mock_rpc + + request = {} + client.undeploy_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.undeploy_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_undeploy_index_rest_required_fields( request_type=index_endpoint_service.UndeployIndexRequest, ): @@ -6323,6 +7378,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_mutate_deployed_index_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.mutate_deployed_index + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.mutate_deployed_index + ] = mock_rpc + + request = {} + client.mutate_deployed_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.mutate_deployed_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_mutate_deployed_index_rest_required_fields( request_type=index_endpoint_service.MutateDeployedIndexRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_index_service.py b/tests/unit/gapic/aiplatform_v1/test_index_service.py index e3ca4b61a8..3e946272f4 100644 --- a/tests/unit/gapic/aiplatform_v1/test_index_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_index_service.py @@ -1157,6 +1157,9 @@ def test_create_index_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_index() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1180,6 +1183,9 @@ def test_create_index_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_index(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1188,6 +1194,45 @@ def test_create_index_non_empty_request_with_auto_populated_field(): ) +def test_create_index_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_index] = mock_rpc + request = {} + client.create_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_index_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1209,6 +1254,56 @@ async def test_create_index_empty_call_async(): assert args[0] == index_service.CreateIndexRequest() +@pytest.mark.asyncio +async def test_create_index_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_index + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_index + ] = mock_object + + request = {} + await client.create_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_index_async( transport: str = "grpc_asyncio", request_type=index_service.CreateIndexRequest @@ -1454,6 +1549,9 @@ def test_get_index_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_index() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1477,6 +1575,9 @@ def test_get_index_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_index(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1485,6 +1586,41 @@ def test_get_index_non_empty_request_with_auto_populated_field(): ) +def test_get_index_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_index] = mock_rpc + request = {} + client.get_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_index_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1513,6 +1649,50 @@ async def test_get_index_empty_call_async(): assert args[0] == index_service.GetIndexRequest() +@pytest.mark.asyncio +async def test_get_index_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_index + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_index + ] = mock_object + + request = {} + await client.get_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_index_async( transport: str = "grpc_asyncio", request_type=index_service.GetIndexRequest @@ -1747,6 +1927,9 @@ def test_list_indexes_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_indexes() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1772,6 +1955,9 @@ def test_list_indexes_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_indexes(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1782,6 +1968,41 @@ def test_list_indexes_non_empty_request_with_auto_populated_field(): ) +def test_list_indexes_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_indexes in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_indexes] = mock_rpc + request = {} + client.list_indexes(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_indexes(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_indexes_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1805,6 +2026,52 @@ async def test_list_indexes_empty_call_async(): assert args[0] == index_service.ListIndexesRequest() +@pytest.mark.asyncio +async def test_list_indexes_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_indexes + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_indexes + ] = mock_object + + request = {} + await client.list_indexes(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_indexes(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_indexes_async( transport: str = "grpc_asyncio", request_type=index_service.ListIndexesRequest @@ -2220,6 +2487,9 @@ def test_update_index_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_index() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2241,12 +2511,54 @@ def test_update_index_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_index(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == index_service.UpdateIndexRequest() +def test_update_index_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[client._transport.update_index] = mock_rpc + request = {} + client.update_index(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_index_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2268,6 +2580,56 @@ async def test_update_index_empty_call_async(): assert args[0] == index_service.UpdateIndexRequest() +@pytest.mark.asyncio +async def test_update_index_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_index + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_index + ] = mock_object + + request = {} + await client.update_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_index_async( transport: str = "grpc_asyncio", request_type=index_service.UpdateIndexRequest @@ -2500,6 +2862,9 @@ def test_delete_index_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_index() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2523,6 +2888,9 @@ def test_delete_index_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_index(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2531,6 +2899,45 @@ def test_delete_index_non_empty_request_with_auto_populated_field(): ) +def test_delete_index_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_index] = mock_rpc + request = {} + client.delete_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_index_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2552,6 +2959,56 @@ async def test_delete_index_empty_call_async(): assert args[0] == index_service.DeleteIndexRequest() +@pytest.mark.asyncio +async def test_delete_index_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_index + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_index + ] = mock_object + + request = {} + await client.delete_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_index_async( transport: str = "grpc_asyncio", request_type=index_service.DeleteIndexRequest @@ -2778,6 +3235,9 @@ def test_upsert_datapoints_empty_call(): with mock.patch.object( type(client.transport.upsert_datapoints), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.upsert_datapoints() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2803,6 +3263,9 @@ def test_upsert_datapoints_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.upsert_datapoints), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.upsert_datapoints(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2811,6 +3274,43 @@ def test_upsert_datapoints_non_empty_request_with_auto_populated_field(): ) +def test_upsert_datapoints_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.upsert_datapoints in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.upsert_datapoints + ] = mock_rpc + request = {} + client.upsert_datapoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.upsert_datapoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_upsert_datapoints_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2834,6 +3334,52 @@ async def test_upsert_datapoints_empty_call_async(): assert args[0] == index_service.UpsertDatapointsRequest() +@pytest.mark.asyncio +async def test_upsert_datapoints_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.upsert_datapoints + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.upsert_datapoints + ] = mock_object + + request = {} + await client.upsert_datapoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.upsert_datapoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_upsert_datapoints_async( transport: str = "grpc_asyncio", request_type=index_service.UpsertDatapointsRequest @@ -2984,6 +3530,9 @@ def test_remove_datapoints_empty_call(): with mock.patch.object( type(client.transport.remove_datapoints), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.remove_datapoints() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3009,6 +3558,9 @@ def test_remove_datapoints_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.remove_datapoints), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.remove_datapoints(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3017,6 +3569,43 @@ def test_remove_datapoints_non_empty_request_with_auto_populated_field(): ) +def test_remove_datapoints_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.remove_datapoints in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.remove_datapoints + ] = mock_rpc + request = {} + client.remove_datapoints(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.remove_datapoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_remove_datapoints_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3040,6 +3629,52 @@ async def test_remove_datapoints_empty_call_async(): assert args[0] == index_service.RemoveDatapointsRequest() +@pytest.mark.asyncio +async def test_remove_datapoints_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.remove_datapoints + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.remove_datapoints + ] = mock_object + + request = {} + await client.remove_datapoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.remove_datapoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_remove_datapoints_async( transport: str = "grpc_asyncio", request_type=index_service.RemoveDatapointsRequest @@ -3273,6 +3908,46 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_index_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_index] = mock_rpc + + request = {} + client.create_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_index_rest_required_fields( request_type=index_service.CreateIndexRequest, ): @@ -3557,6 +4232,42 @@ def test_get_index_rest(request_type): assert response.index_update_method == index.Index.IndexUpdateMethod.BATCH_UPDATE +def test_get_index_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_index] = mock_rpc + + request = {} + client.get_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_index_rest_required_fields(request_type=index_service.GetIndexRequest): transport_class = transports.IndexServiceRestTransport @@ -3817,6 +4528,42 @@ def test_list_indexes_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_indexes_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_indexes in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_indexes] = mock_rpc + + request = {} + client.list_indexes(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_indexes(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_indexes_rest_required_fields( request_type=index_service.ListIndexesRequest, ): @@ -4249,6 +4996,46 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_index_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_index] = mock_rpc + + request = {} + client.update_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_index_rest_required_fields( request_type=index_service.UpdateIndexRequest, ): @@ -4512,6 +5299,46 @@ def test_delete_index_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_index_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_index] = mock_rpc + + request = {} + client.delete_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_index_rest_required_fields( request_type=index_service.DeleteIndexRequest, ): @@ -4772,6 +5599,44 @@ def test_upsert_datapoints_rest(request_type): assert isinstance(response, index_service.UpsertDatapointsResponse) +def test_upsert_datapoints_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.upsert_datapoints in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.upsert_datapoints + ] = mock_rpc + + request = {} + client.upsert_datapoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.upsert_datapoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_upsert_datapoints_rest_required_fields( request_type=index_service.UpsertDatapointsRequest, ): @@ -4980,6 +5845,44 @@ def test_remove_datapoints_rest(request_type): assert isinstance(response, index_service.RemoveDatapointsResponse) +def test_remove_datapoints_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.remove_datapoints in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.remove_datapoints + ] = mock_rpc + + request = {} + client.remove_datapoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.remove_datapoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_remove_datapoints_rest_required_fields( request_type=index_service.RemoveDatapointsRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_job_service.py b/tests/unit/gapic/aiplatform_v1/test_job_service.py index 8a6ad0f417..d5619fa394 100644 --- a/tests/unit/gapic/aiplatform_v1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_job_service.py @@ -1186,6 +1186,9 @@ def test_create_custom_job_empty_call(): with mock.patch.object( type(client.transport.create_custom_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1211,6 +1214,9 @@ def test_create_custom_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_custom_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_custom_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1219,6 +1225,43 @@ def test_create_custom_job_non_empty_request_with_auto_populated_field(): ) +def test_create_custom_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_custom_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_custom_job + ] = mock_rpc + request = {} + client.create_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_custom_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1246,6 +1289,52 @@ async def test_create_custom_job_empty_call_async(): assert args[0] == job_service.CreateCustomJobRequest() +@pytest.mark.asyncio +async def test_create_custom_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_custom_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_custom_job + ] = mock_object + + request = {} + await client.create_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_custom_job_async( transport: str = "grpc_asyncio", request_type=job_service.CreateCustomJobRequest @@ -1502,6 +1591,9 @@ def test_get_custom_job_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1525,6 +1617,9 @@ def test_get_custom_job_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_custom_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1533,6 +1628,41 @@ def test_get_custom_job_non_empty_request_with_auto_populated_field(): ) +def test_get_custom_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_custom_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_custom_job] = mock_rpc + request = {} + client.get_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_custom_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1558,6 +1688,52 @@ async def test_get_custom_job_empty_call_async(): assert args[0] == job_service.GetCustomJobRequest() +@pytest.mark.asyncio +async def test_get_custom_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_custom_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_custom_job + ] = mock_object + + request = {} + await client.get_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_custom_job_async( transport: str = "grpc_asyncio", request_type=job_service.GetCustomJobRequest @@ -1790,6 +1966,9 @@ def test_list_custom_jobs_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_custom_jobs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1815,6 +1994,9 @@ def test_list_custom_jobs_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_custom_jobs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1825,6 +2007,43 @@ def test_list_custom_jobs_non_empty_request_with_auto_populated_field(): ) +def test_list_custom_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_custom_jobs in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_custom_jobs + ] = mock_rpc + request = {} + client.list_custom_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_custom_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_custom_jobs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1848,6 +2067,52 @@ async def test_list_custom_jobs_empty_call_async(): assert args[0] == job_service.ListCustomJobsRequest() +@pytest.mark.asyncio +async def test_list_custom_jobs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_custom_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_custom_jobs + ] = mock_object + + request = {} + await client.list_custom_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_custom_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_custom_jobs_async( transport: str = "grpc_asyncio", request_type=job_service.ListCustomJobsRequest @@ -2267,6 +2532,9 @@ def test_delete_custom_job_empty_call(): with mock.patch.object( type(client.transport.delete_custom_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2292,6 +2560,9 @@ def test_delete_custom_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_custom_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_custom_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2300,6 +2571,47 @@ def test_delete_custom_job_non_empty_request_with_auto_populated_field(): ) +def test_delete_custom_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_custom_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_custom_job + ] = mock_rpc + request = {} + client.delete_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_custom_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2323,6 +2635,56 @@ async def test_delete_custom_job_empty_call_async(): assert args[0] == job_service.DeleteCustomJobRequest() +@pytest.mark.asyncio +async def test_delete_custom_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_custom_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_custom_job + ] = mock_object + + request = {} + await client.delete_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_custom_job_async( transport: str = "grpc_asyncio", request_type=job_service.DeleteCustomJobRequest @@ -2559,6 +2921,9 @@ def test_cancel_custom_job_empty_call(): with mock.patch.object( type(client.transport.cancel_custom_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.cancel_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2584,6 +2949,9 @@ def test_cancel_custom_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.cancel_custom_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.cancel_custom_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2592,6 +2960,43 @@ def test_cancel_custom_job_non_empty_request_with_auto_populated_field(): ) +def test_cancel_custom_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.cancel_custom_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_custom_job + ] = mock_rpc + request = {} + client.cancel_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_custom_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2613,6 +3018,52 @@ async def test_cancel_custom_job_empty_call_async(): assert args[0] == job_service.CancelCustomJobRequest() +@pytest.mark.asyncio +async def test_cancel_custom_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.cancel_custom_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.cancel_custom_job + ] = mock_object + + request = {} + await client.cancel_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.cancel_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_custom_job_async( transport: str = "grpc_asyncio", request_type=job_service.CancelCustomJobRequest @@ -2862,6 +3313,9 @@ def test_create_data_labeling_job_empty_call(): with mock.patch.object( type(client.transport.create_data_labeling_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2887,6 +3341,9 @@ def test_create_data_labeling_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_data_labeling_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_data_labeling_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2895,6 +3352,46 @@ def test_create_data_labeling_job_non_empty_request_with_auto_populated_field(): ) +def test_create_data_labeling_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_data_labeling_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_data_labeling_job + ] = mock_rpc + request = {} + client.create_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_data_labeling_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2928,6 +3425,52 @@ async def test_create_data_labeling_job_empty_call_async(): assert args[0] == job_service.CreateDataLabelingJobRequest() +@pytest.mark.asyncio +async def test_create_data_labeling_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_data_labeling_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_data_labeling_job + ] = mock_object + + request = {} + await client.create_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_data_labeling_job_async( transport: str = "grpc_asyncio", @@ -3213,6 +3756,9 @@ def test_get_data_labeling_job_empty_call(): with mock.patch.object( type(client.transport.get_data_labeling_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3238,6 +3784,9 @@ def test_get_data_labeling_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_data_labeling_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_data_labeling_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3246,6 +3795,46 @@ def test_get_data_labeling_job_non_empty_request_with_auto_populated_field(): ) +def test_get_data_labeling_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_data_labeling_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.get_data_labeling_job + ] = mock_rpc + request = {} + client.get_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_data_labeling_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3280,11 +3869,57 @@ async def test_get_data_labeling_job_empty_call_async(): @pytest.mark.asyncio -async def test_get_data_labeling_job_async( - transport: str = "grpc_asyncio", request_type=job_service.GetDataLabelingJobRequest +async def test_get_data_labeling_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", ): - client = JobServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_data_labeling_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_data_labeling_job + ] = mock_object + + request = {} + await client.get_data_labeling_job(request) + + # Establish that the underlying gRPC 
stub method was called. + assert mock_object.call_count == 1 + + await client.get_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_data_labeling_job_async( + transport: str = "grpc_asyncio", request_type=job_service.GetDataLabelingJobRequest +): + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3537,6 +4172,9 @@ def test_list_data_labeling_jobs_empty_call(): with mock.patch.object( type(client.transport.list_data_labeling_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_data_labeling_jobs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3565,6 +4203,9 @@ def test_list_data_labeling_jobs_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_data_labeling_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_data_labeling_jobs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3576,6 +4217,46 @@ def test_list_data_labeling_jobs_non_empty_request_with_auto_populated_field(): ) +def test_list_data_labeling_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_data_labeling_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_data_labeling_jobs + ] = mock_rpc + request = {} + client.list_data_labeling_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_data_labeling_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_data_labeling_jobs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3601,6 +4282,52 @@ async def test_list_data_labeling_jobs_empty_call_async(): assert args[0] == job_service.ListDataLabelingJobsRequest() +@pytest.mark.asyncio +async def test_list_data_labeling_jobs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_data_labeling_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_data_labeling_jobs + ] = mock_object + + request = {} + await client.list_data_labeling_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_data_labeling_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_data_labeling_jobs_async( transport: str = "grpc_asyncio", @@ -4039,6 +4766,9 @@ def test_delete_data_labeling_job_empty_call(): with mock.patch.object( type(client.transport.delete_data_labeling_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4064,6 +4794,9 @@ def test_delete_data_labeling_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_data_labeling_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_data_labeling_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4072,6 +4805,50 @@ def test_delete_data_labeling_job_non_empty_request_with_auto_populated_field(): ) +def test_delete_data_labeling_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_data_labeling_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a 
string. + ) + client._transport._wrapped_methods[ + client._transport.delete_data_labeling_job + ] = mock_rpc + request = {} + client.delete_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_data_labeling_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4095,6 +4872,56 @@ async def test_delete_data_labeling_job_empty_call_async(): assert args[0] == job_service.DeleteDataLabelingJobRequest() +@pytest.mark.asyncio +async def test_delete_data_labeling_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_data_labeling_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_data_labeling_job + ] = mock_object + + request = {} + await client.delete_data_labeling_job(request) + + # Establish that 
the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_data_labeling_job_async( transport: str = "grpc_asyncio", @@ -4332,6 +5159,9 @@ def test_cancel_data_labeling_job_empty_call(): with mock.patch.object( type(client.transport.cancel_data_labeling_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.cancel_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4357,6 +5187,9 @@ def test_cancel_data_labeling_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.cancel_data_labeling_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.cancel_data_labeling_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4365,6 +5198,46 @@ def test_cancel_data_labeling_job_non_empty_request_with_auto_populated_field(): ) +def test_cancel_data_labeling_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.cancel_data_labeling_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_data_labeling_job + ] = mock_rpc + request = {} + client.cancel_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_data_labeling_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4386,6 +5259,52 @@ async def test_cancel_data_labeling_job_empty_call_async(): assert args[0] == job_service.CancelDataLabelingJobRequest() +@pytest.mark.asyncio +async def test_cancel_data_labeling_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.cancel_data_labeling_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.cancel_data_labeling_job + ] = mock_object + + request = {} + await client.cancel_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.cancel_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_data_labeling_job_async( transport: str = "grpc_asyncio", @@ -4630,6 +5549,9 @@ def test_create_hyperparameter_tuning_job_empty_call(): with mock.patch.object( type(client.transport.create_hyperparameter_tuning_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4655,6 +5577,9 @@ def test_create_hyperparameter_tuning_job_non_empty_request_with_auto_populated_ with mock.patch.object( type(client.transport.create_hyperparameter_tuning_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_hyperparameter_tuning_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4663,6 +5588,46 @@ def test_create_hyperparameter_tuning_job_non_empty_request_with_auto_populated_ ) +def test_create_hyperparameter_tuning_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_hyperparameter_tuning_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_hyperparameter_tuning_job + ] = mock_rpc + request = {} + client.create_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4693,6 +5658,52 @@ async def test_create_hyperparameter_tuning_job_empty_call_async(): assert args[0] == job_service.CreateHyperparameterTuningJobRequest() +@pytest.mark.asyncio +async def test_create_hyperparameter_tuning_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_hyperparameter_tuning_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_hyperparameter_tuning_job + ] = mock_object + + request = {} + await client.create_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_async( transport: str = "grpc_asyncio", @@ -4978,6 +5989,9 @@ def test_get_hyperparameter_tuning_job_empty_call(): with mock.patch.object( type(client.transport.get_hyperparameter_tuning_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5003,6 +6017,9 @@ def test_get_hyperparameter_tuning_job_non_empty_request_with_auto_populated_fie with mock.patch.object( type(client.transport.get_hyperparameter_tuning_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_hyperparameter_tuning_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5011,6 +6028,46 @@ def test_get_hyperparameter_tuning_job_non_empty_request_with_auto_populated_fie ) +def test_get_hyperparameter_tuning_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_hyperparameter_tuning_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_hyperparameter_tuning_job + ] = mock_rpc + request = {} + client.get_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5041,6 +6098,52 @@ async def test_get_hyperparameter_tuning_job_empty_call_async(): assert args[0] == job_service.GetHyperparameterTuningJobRequest() +@pytest.mark.asyncio +async def test_get_hyperparameter_tuning_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_hyperparameter_tuning_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_hyperparameter_tuning_job + ] = mock_object + + request = {} + await client.get_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_async( transport: str = "grpc_asyncio", @@ -5294,6 +6397,9 @@ def test_list_hyperparameter_tuning_jobs_empty_call(): with mock.patch.object( type(client.transport.list_hyperparameter_tuning_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_hyperparameter_tuning_jobs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5321,6 +6427,9 @@ def test_list_hyperparameter_tuning_jobs_non_empty_request_with_auto_populated_f with mock.patch.object( type(client.transport.list_hyperparameter_tuning_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_hyperparameter_tuning_jobs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5331,6 +6440,46 @@ def test_list_hyperparameter_tuning_jobs_non_empty_request_with_auto_populated_f ) +def test_list_hyperparameter_tuning_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_hyperparameter_tuning_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_hyperparameter_tuning_jobs + ] = mock_rpc + request = {} + client.list_hyperparameter_tuning_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_hyperparameter_tuning_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5357,17 +6506,63 @@ async def test_list_hyperparameter_tuning_jobs_empty_call_async(): @pytest.mark.asyncio -async def test_list_hyperparameter_tuning_jobs_async( +async def test_list_hyperparameter_tuning_jobs_async_use_cached_wrapped_rpc( transport: str = "grpc_asyncio", - request_type=job_service.ListHyperparameterTuningJobsRequest, ): - client = JobServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. 
+ # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_hyperparameter_tuning_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_hyperparameter_tuning_jobs + ] = mock_object + + request = {} + await client.list_hyperparameter_tuning_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.list_hyperparameter_tuning_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_hyperparameter_tuning_jobs_async( + transport: str = "grpc_asyncio", + request_type=job_service.ListHyperparameterTuningJobsRequest, +): + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. @@ -5800,6 +6995,9 @@ def test_delete_hyperparameter_tuning_job_empty_call(): with mock.patch.object( type(client.transport.delete_hyperparameter_tuning_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5825,6 +7023,9 @@ def test_delete_hyperparameter_tuning_job_non_empty_request_with_auto_populated_ with mock.patch.object( type(client.transport.delete_hyperparameter_tuning_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_hyperparameter_tuning_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5833,6 +7034,50 @@ def test_delete_hyperparameter_tuning_job_non_empty_request_with_auto_populated_ ) +def test_delete_hyperparameter_tuning_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_hyperparameter_tuning_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_hyperparameter_tuning_job + ] = mock_rpc + request = {} + client.delete_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5856,6 +7101,56 @@ async def test_delete_hyperparameter_tuning_job_empty_call_async(): assert args[0] == job_service.DeleteHyperparameterTuningJobRequest() +@pytest.mark.asyncio +async def test_delete_hyperparameter_tuning_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_hyperparameter_tuning_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_hyperparameter_tuning_job + ] = mock_object + + request = {} + await client.delete_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_async( transport: str = "grpc_asyncio", @@ -6093,6 +7388,9 @@ def test_cancel_hyperparameter_tuning_job_empty_call(): with mock.patch.object( type(client.transport.cancel_hyperparameter_tuning_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.cancel_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6118,6 +7416,9 @@ def test_cancel_hyperparameter_tuning_job_non_empty_request_with_auto_populated_ with mock.patch.object( type(client.transport.cancel_hyperparameter_tuning_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.cancel_hyperparameter_tuning_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6126,6 +7427,46 @@ def test_cancel_hyperparameter_tuning_job_non_empty_request_with_auto_populated_ ) +def test_cancel_hyperparameter_tuning_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.cancel_hyperparameter_tuning_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_hyperparameter_tuning_job + ] = mock_rpc + request = {} + client.cancel_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6147,6 +7488,52 @@ async def test_cancel_hyperparameter_tuning_job_empty_call_async(): assert args[0] == job_service.CancelHyperparameterTuningJobRequest() +@pytest.mark.asyncio +async def test_cancel_hyperparameter_tuning_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.cancel_hyperparameter_tuning_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.cancel_hyperparameter_tuning_job + ] = mock_object + + request = {} + await client.cancel_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.cancel_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_async( transport: str = "grpc_asyncio", @@ -6383,6 +7770,9 @@ def test_create_nas_job_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_nas_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_nas_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6406,6 +7796,9 @@ def test_create_nas_job_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_nas_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_nas_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6414,6 +7807,41 @@ def test_create_nas_job_non_empty_request_with_auto_populated_field(): ) +def test_create_nas_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_nas_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_nas_job] = mock_rpc + request = {} + client.create_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_nas_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6440,6 +7868,52 @@ async def test_create_nas_job_empty_call_async(): assert args[0] == job_service.CreateNasJobRequest() +@pytest.mark.asyncio +async def test_create_nas_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_nas_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_nas_job + ] = mock_object + + request = {} + await client.create_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_nas_job_async( transport: str = "grpc_asyncio", request_type=job_service.CreateNasJobRequest @@ -6686,6 +8160,9 @@ def test_get_nas_job_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_nas_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_nas_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6709,6 +8186,9 @@ def test_get_nas_job_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_nas_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_nas_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6717,6 +8197,41 @@ def test_get_nas_job_non_empty_request_with_auto_populated_field(): ) +def test_get_nas_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_nas_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_nas_job] = mock_rpc + request = {} + client.get_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_nas_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6743,6 +8258,52 @@ async def test_get_nas_job_empty_call_async(): assert args[0] == job_service.GetNasJobRequest() +@pytest.mark.asyncio +async def test_get_nas_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_nas_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_nas_job + ] = mock_object + + request = {} + await client.get_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_nas_job_async( transport: str = "grpc_asyncio", request_type=job_service.GetNasJobRequest @@ -6973,6 +8534,9 @@ def test_list_nas_jobs_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_nas_jobs), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_nas_jobs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6998,6 +8562,9 @@ def test_list_nas_jobs_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_nas_jobs), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_nas_jobs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7008,6 +8575,41 @@ def test_list_nas_jobs_non_empty_request_with_auto_populated_field(): ) +def test_list_nas_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_nas_jobs in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_nas_jobs] = mock_rpc + request = {} + client.list_nas_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_nas_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_nas_jobs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7031,6 +8633,52 @@ async def test_list_nas_jobs_empty_call_async(): assert args[0] == job_service.ListNasJobsRequest() +@pytest.mark.asyncio +async def test_list_nas_jobs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_nas_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_nas_jobs + ] = mock_object + + request = {} + await client.list_nas_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_nas_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_nas_jobs_async( transport: str = "grpc_asyncio", request_type=job_service.ListNasJobsRequest @@ -7446,6 +9094,9 @@ def test_delete_nas_job_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_nas_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_nas_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7469,6 +9120,9 @@ def test_delete_nas_job_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_nas_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_nas_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7477,6 +9131,45 @@ def test_delete_nas_job_non_empty_request_with_auto_populated_field(): ) +def test_delete_nas_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_nas_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_nas_job] = mock_rpc + request = {} + client.delete_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_nas_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7499,21 +9192,71 @@ async def test_delete_nas_job_empty_call_async(): @pytest.mark.asyncio -async def test_delete_nas_job_async( - transport: str = "grpc_asyncio", request_type=job_service.DeleteNasJobRequest +async def test_delete_nas_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", ): - client = JobServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_nas_job), "__call__") as call: - # Designate an appropriate return value for the call. 
+ # Ensure method has been cached + assert ( + client._client._transport.delete_nas_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_nas_job + ] = mock_object + + request = {} + await client.delete_nas_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_nas_job_async( + transport: str = "grpc_asyncio", request_type=job_service.DeleteNasJobRequest +): + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_nas_job), "__call__") as call: + # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( operations_pb2.Operation(name="operations/spam") ) @@ -7720,6 +9463,9 @@ def test_cancel_nas_job_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.cancel_nas_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.cancel_nas_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7743,6 +9489,9 @@ def test_cancel_nas_job_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.cancel_nas_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.cancel_nas_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7751,6 +9500,41 @@ def test_cancel_nas_job_non_empty_request_with_auto_populated_field(): ) +def test_cancel_nas_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.cancel_nas_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.cancel_nas_job] = mock_rpc + request = {} + client.cancel_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_nas_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7770,6 +9554,52 @@ async def test_cancel_nas_job_empty_call_async(): assert args[0] == job_service.CancelNasJobRequest() +@pytest.mark.asyncio +async def test_cancel_nas_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.cancel_nas_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.cancel_nas_job + ] = mock_object + + request = {} + await client.cancel_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.cancel_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_nas_job_async( transport: str = "grpc_asyncio", request_type=job_service.CancelNasJobRequest @@ -7995,6 +9825,9 @@ def test_get_nas_trial_detail_empty_call(): with mock.patch.object( type(client.transport.get_nas_trial_detail), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_nas_trial_detail() call.assert_called() _, args, _ = call.mock_calls[0] @@ -8020,6 +9853,9 @@ def test_get_nas_trial_detail_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_nas_trial_detail), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_nas_trial_detail(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -8028,6 +9864,45 @@ def test_get_nas_trial_detail_non_empty_request_with_auto_populated_field(): ) +def test_get_nas_trial_detail_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_nas_trial_detail in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.get_nas_trial_detail + ] = mock_rpc + request = {} + client.get_nas_trial_detail(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_nas_trial_detail(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_nas_trial_detail_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -8054,6 +9929,52 @@ async def test_get_nas_trial_detail_empty_call_async(): assert args[0] == job_service.GetNasTrialDetailRequest() +@pytest.mark.asyncio +async def test_get_nas_trial_detail_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_nas_trial_detail + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_nas_trial_detail + ] = mock_object + + request = {} + await client.get_nas_trial_detail(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_nas_trial_detail(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_nas_trial_detail_async( transport: str = "grpc_asyncio", request_type=job_service.GetNasTrialDetailRequest @@ -8298,6 +10219,9 @@ def test_list_nas_trial_details_empty_call(): with mock.patch.object( type(client.transport.list_nas_trial_details), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_nas_trial_details() call.assert_called() _, args, _ = call.mock_calls[0] @@ -8324,6 +10248,9 @@ def test_list_nas_trial_details_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_nas_trial_details), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_nas_trial_details(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -8333,6 +10260,46 @@ def test_list_nas_trial_details_non_empty_request_with_auto_populated_field(): ) +def test_list_nas_trial_details_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_nas_trial_details + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_nas_trial_details + ] = mock_rpc + request = {} + client.list_nas_trial_details(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_nas_trial_details(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_nas_trial_details_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -8358,6 +10325,52 @@ async def test_list_nas_trial_details_empty_call_async(): assert args[0] == job_service.ListNasTrialDetailsRequest() +@pytest.mark.asyncio +async def test_list_nas_trial_details_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_nas_trial_details + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_nas_trial_details + ] = mock_object + + request = {} + await client.list_nas_trial_details(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_nas_trial_details(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_nas_trial_details_async( transport: str = "grpc_asyncio", request_type=job_service.ListNasTrialDetailsRequest @@ -8812,6 +10825,9 @@ def test_create_batch_prediction_job_empty_call(): with mock.patch.object( type(client.transport.create_batch_prediction_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -8837,6 +10853,9 @@ def test_create_batch_prediction_job_non_empty_request_with_auto_populated_field with mock.patch.object( type(client.transport.create_batch_prediction_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_batch_prediction_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -8845,6 +10864,46 @@ def test_create_batch_prediction_job_non_empty_request_with_auto_populated_field ) +def test_create_batch_prediction_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_batch_prediction_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_batch_prediction_job + ] = mock_rpc + request = {} + client.create_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_batch_prediction_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -8877,6 +10936,52 @@ async def test_create_batch_prediction_job_empty_call_async(): assert args[0] == job_service.CreateBatchPredictionJobRequest() +@pytest.mark.asyncio +async def test_create_batch_prediction_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_batch_prediction_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_batch_prediction_job + ] = mock_object + + request = {} + await client.create_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_batch_prediction_job_async( transport: str = "grpc_asyncio", @@ -9166,6 +11271,9 @@ def test_get_batch_prediction_job_empty_call(): with mock.patch.object( type(client.transport.get_batch_prediction_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -9191,6 +11299,9 @@ def test_get_batch_prediction_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_batch_prediction_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_batch_prediction_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -9199,6 +11310,46 @@ def test_get_batch_prediction_job_non_empty_request_with_auto_populated_field(): ) +def test_get_batch_prediction_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_batch_prediction_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_batch_prediction_job + ] = mock_rpc + request = {} + client.get_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_batch_prediction_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -9231,6 +11382,52 @@ async def test_get_batch_prediction_job_empty_call_async(): assert args[0] == job_service.GetBatchPredictionJobRequest() +@pytest.mark.asyncio +async def test_get_batch_prediction_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_batch_prediction_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_batch_prediction_job + ] = mock_object + + request = {} + await client.get_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_batch_prediction_job_async( transport: str = "grpc_asyncio", @@ -9488,6 +11685,9 @@ def test_list_batch_prediction_jobs_empty_call(): with mock.patch.object( type(client.transport.list_batch_prediction_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_batch_prediction_jobs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -9515,6 +11715,9 @@ def test_list_batch_prediction_jobs_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.list_batch_prediction_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_batch_prediction_jobs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -9525,6 +11728,46 @@ def test_list_batch_prediction_jobs_non_empty_request_with_auto_populated_field( ) +def test_list_batch_prediction_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_batch_prediction_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_batch_prediction_jobs + ] = mock_rpc + request = {} + client.list_batch_prediction_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_batch_prediction_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_batch_prediction_jobs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -9550,6 +11793,52 @@ async def test_list_batch_prediction_jobs_empty_call_async(): assert args[0] == job_service.ListBatchPredictionJobsRequest() +@pytest.mark.asyncio +async def test_list_batch_prediction_jobs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_batch_prediction_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_batch_prediction_jobs + ] = mock_object + + request = {} + await client.list_batch_prediction_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_batch_prediction_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_batch_prediction_jobs_async( transport: str = "grpc_asyncio", @@ -9992,6 +12281,9 @@ def test_delete_batch_prediction_job_empty_call(): with mock.patch.object( type(client.transport.delete_batch_prediction_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -10017,6 +12309,9 @@ def test_delete_batch_prediction_job_non_empty_request_with_auto_populated_field with mock.patch.object( type(client.transport.delete_batch_prediction_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_batch_prediction_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -10025,27 +12320,121 @@ def test_delete_batch_prediction_job_non_empty_request_with_auto_populated_field ) -@pytest.mark.asyncio -async def test_delete_batch_prediction_job_empty_call_async(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. 
- client = JobServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc_asyncio", - ) +def test_delete_batch_prediction_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_batch_prediction_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_batch_prediction_job + ] = mock_rpc + request = {} + client.delete_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_empty_call_async(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.delete_batch_prediction_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == job_service.DeleteBatchPredictionJobRequest() + + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_batch_prediction_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_batch_prediction_job + ] = mock_object + + request = {} + await client.delete_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_batch_prediction_job(request) - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.delete_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - response = await client.delete_batch_prediction_job() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == job_service.DeleteBatchPredictionJobRequest() + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 @pytest.mark.asyncio @@ -10285,6 +12674,9 @@ def test_cancel_batch_prediction_job_empty_call(): with mock.patch.object( type(client.transport.cancel_batch_prediction_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.cancel_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -10310,6 +12702,9 @@ def test_cancel_batch_prediction_job_non_empty_request_with_auto_populated_field with mock.patch.object( type(client.transport.cancel_batch_prediction_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.cancel_batch_prediction_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -10318,6 +12713,46 @@ def test_cancel_batch_prediction_job_non_empty_request_with_auto_populated_field ) +def test_cancel_batch_prediction_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.cancel_batch_prediction_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_batch_prediction_job + ] = mock_rpc + request = {} + client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_batch_prediction_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -10339,6 +12774,52 @@ async def test_cancel_batch_prediction_job_empty_call_async(): assert args[0] == job_service.CancelBatchPredictionJobRequest() +@pytest.mark.asyncio +async def test_cancel_batch_prediction_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.cancel_batch_prediction_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.cancel_batch_prediction_job + ] = mock_object + + request = {} + await client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.cancel_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_batch_prediction_job_async( transport: str = "grpc_asyncio", @@ -10592,6 +13073,9 @@ def test_create_model_deployment_monitoring_job_empty_call(): with mock.patch.object( type(client.transport.create_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_model_deployment_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -10617,6 +13101,9 @@ def test_create_model_deployment_monitoring_job_non_empty_request_with_auto_popu with mock.patch.object( type(client.transport.create_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_model_deployment_monitoring_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -10625,6 +13112,46 @@ def test_create_model_deployment_monitoring_job_non_empty_request_with_auto_popu ) +def test_create_model_deployment_monitoring_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_model_deployment_monitoring_job + ] = mock_rpc + request = {} + client.create_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_model_deployment_monitoring_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -10657,6 +13184,52 @@ async def test_create_model_deployment_monitoring_job_empty_call_async(): assert args[0] == job_service.CreateModelDeploymentMonitoringJobRequest() +@pytest.mark.asyncio +async def test_create_model_deployment_monitoring_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_model_deployment_monitoring_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_model_deployment_monitoring_job + ] = mock_object + + request = {} + await client.create_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_model_deployment_monitoring_job_async( transport: str = "grpc_asyncio", @@ -10955,6 +13528,9 @@ def test_search_model_deployment_monitoring_stats_anomalies_empty_call(): type(client.transport.search_model_deployment_monitoring_stats_anomalies), "__call__", ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.search_model_deployment_monitoring_stats_anomalies() call.assert_called() _, args, _ = call.mock_calls[0] @@ -10987,6 +13563,9 @@ def test_search_model_deployment_monitoring_stats_anomalies_non_empty_request_wi type(client.transport.search_model_deployment_monitoring_stats_anomalies), "__call__", ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.search_model_deployment_monitoring_stats_anomalies(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -11000,6 +13579,46 @@ def test_search_model_deployment_monitoring_stats_anomalies_non_empty_request_wi ) +def test_search_model_deployment_monitoring_stats_anomalies_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.search_model_deployment_monitoring_stats_anomalies + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.search_model_deployment_monitoring_stats_anomalies + ] = mock_rpc + request = {} + client.search_model_deployment_monitoring_stats_anomalies(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_model_deployment_monitoring_stats_anomalies(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_search_model_deployment_monitoring_stats_anomalies_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -11029,6 +13648,52 @@ async def test_search_model_deployment_monitoring_stats_anomalies_empty_call_asy ) +@pytest.mark.asyncio +async def test_search_model_deployment_monitoring_stats_anomalies_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.search_model_deployment_monitoring_stats_anomalies + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.search_model_deployment_monitoring_stats_anomalies + ] = mock_object + + request = {} + await client.search_model_deployment_monitoring_stats_anomalies(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.search_model_deployment_monitoring_stats_anomalies(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_search_model_deployment_monitoring_stats_anomalies_async( transport: str = "grpc_asyncio", @@ -11536,6 +14201,9 @@ def test_get_model_deployment_monitoring_job_empty_call(): with mock.patch.object( type(client.transport.get_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_model_deployment_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -11561,6 +14229,9 @@ def test_get_model_deployment_monitoring_job_non_empty_request_with_auto_populat with mock.patch.object( type(client.transport.get_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_model_deployment_monitoring_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -11569,6 +14240,46 @@ def test_get_model_deployment_monitoring_job_non_empty_request_with_auto_populat ) +def test_get_model_deployment_monitoring_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_model_deployment_monitoring_job + ] = mock_rpc + request = {} + client.get_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_model_deployment_monitoring_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -11601,6 +14312,52 @@ async def test_get_model_deployment_monitoring_job_empty_call_async(): assert args[0] == job_service.GetModelDeploymentMonitoringJobRequest() +@pytest.mark.asyncio +async def test_get_model_deployment_monitoring_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_model_deployment_monitoring_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_model_deployment_monitoring_job + ] = mock_object + + request = {} + await client.get_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_model_deployment_monitoring_job_async( transport: str = "grpc_asyncio", @@ -11869,6 +14626,9 @@ def test_list_model_deployment_monitoring_jobs_empty_call(): with mock.patch.object( type(client.transport.list_model_deployment_monitoring_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_model_deployment_monitoring_jobs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -11896,6 +14656,9 @@ def test_list_model_deployment_monitoring_jobs_non_empty_request_with_auto_popul with mock.patch.object( type(client.transport.list_model_deployment_monitoring_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_model_deployment_monitoring_jobs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -11906,6 +14669,46 @@ def test_list_model_deployment_monitoring_jobs_non_empty_request_with_auto_popul ) +def test_list_model_deployment_monitoring_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_model_deployment_monitoring_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_model_deployment_monitoring_jobs + ] = mock_rpc + request = {} + client.list_model_deployment_monitoring_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_model_deployment_monitoring_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_model_deployment_monitoring_jobs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -11931,6 +14734,52 @@ async def test_list_model_deployment_monitoring_jobs_empty_call_async(): assert args[0] == job_service.ListModelDeploymentMonitoringJobsRequest() +@pytest.mark.asyncio +async def test_list_model_deployment_monitoring_jobs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_model_deployment_monitoring_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_model_deployment_monitoring_jobs + ] = mock_object + + request = {} + await client.list_model_deployment_monitoring_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_model_deployment_monitoring_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_model_deployment_monitoring_jobs_async( transport: str = "grpc_asyncio", @@ -12375,6 +15224,9 @@ def test_update_model_deployment_monitoring_job_empty_call(): with mock.patch.object( type(client.transport.update_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_model_deployment_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -12398,12 +15250,59 @@ def test_update_model_deployment_monitoring_job_non_empty_request_with_auto_popu with mock.patch.object( type(client.transport.update_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_model_deployment_monitoring_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.UpdateModelDeploymentMonitoringJobRequest() +def test_update_model_deployment_monitoring_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_model_deployment_monitoring_job + ] = mock_rpc + request = {} + client.update_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_model_deployment_monitoring_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -12427,6 +15326,56 @@ async def test_update_model_deployment_monitoring_job_empty_call_async(): assert args[0] == job_service.UpdateModelDeploymentMonitoringJobRequest() +@pytest.mark.asyncio +async def test_update_model_deployment_monitoring_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_model_deployment_monitoring_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_model_deployment_monitoring_job + ] = mock_object + + request = {} + await client.update_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_model_deployment_monitoring_job_async( transport: str = "grpc_asyncio", @@ -12686,6 +15635,9 @@ def test_delete_model_deployment_monitoring_job_empty_call(): with mock.patch.object( type(client.transport.delete_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_model_deployment_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -12711,6 +15663,9 @@ def test_delete_model_deployment_monitoring_job_non_empty_request_with_auto_popu with mock.patch.object( type(client.transport.delete_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_model_deployment_monitoring_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -12719,6 +15674,50 @@ def test_delete_model_deployment_monitoring_job_non_empty_request_with_auto_popu ) +def test_delete_model_deployment_monitoring_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_model_deployment_monitoring_job + ] = mock_rpc + request = {} + client.delete_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_model_deployment_monitoring_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -12742,6 +15741,56 @@ async def test_delete_model_deployment_monitoring_job_empty_call_async(): assert args[0] == job_service.DeleteModelDeploymentMonitoringJobRequest() +@pytest.mark.asyncio +async def test_delete_model_deployment_monitoring_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_model_deployment_monitoring_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_model_deployment_monitoring_job + ] = mock_object + + request = {} + await client.delete_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_model_deployment_monitoring_job_async( transport: str = "grpc_asyncio", @@ -12979,6 +16028,9 @@ def test_pause_model_deployment_monitoring_job_empty_call(): with mock.patch.object( type(client.transport.pause_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.pause_model_deployment_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -13004,6 +16056,9 @@ def test_pause_model_deployment_monitoring_job_non_empty_request_with_auto_popul with mock.patch.object( type(client.transport.pause_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.pause_model_deployment_monitoring_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -13012,6 +16067,46 @@ def test_pause_model_deployment_monitoring_job_non_empty_request_with_auto_popul ) +def test_pause_model_deployment_monitoring_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.pause_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.pause_model_deployment_monitoring_job + ] = mock_rpc + request = {} + client.pause_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.pause_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_pause_model_deployment_monitoring_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -13033,6 +16128,52 @@ async def test_pause_model_deployment_monitoring_job_empty_call_async(): assert args[0] == job_service.PauseModelDeploymentMonitoringJobRequest() +@pytest.mark.asyncio +async def test_pause_model_deployment_monitoring_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.pause_model_deployment_monitoring_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.pause_model_deployment_monitoring_job + ] = mock_object + + request = {} + await client.pause_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.pause_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_pause_model_deployment_monitoring_job_async( transport: str = "grpc_asyncio", @@ -13264,6 +16405,9 @@ def test_resume_model_deployment_monitoring_job_empty_call(): with mock.patch.object( type(client.transport.resume_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.resume_model_deployment_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -13289,6 +16433,9 @@ def test_resume_model_deployment_monitoring_job_non_empty_request_with_auto_popu with mock.patch.object( type(client.transport.resume_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.resume_model_deployment_monitoring_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -13297,6 +16444,46 @@ def test_resume_model_deployment_monitoring_job_non_empty_request_with_auto_popu ) +def test_resume_model_deployment_monitoring_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.resume_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.resume_model_deployment_monitoring_job + ] = mock_rpc + request = {} + client.resume_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.resume_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_resume_model_deployment_monitoring_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -13318,6 +16505,52 @@ async def test_resume_model_deployment_monitoring_job_empty_call_async(): assert args[0] == job_service.ResumeModelDeploymentMonitoringJobRequest() +@pytest.mark.asyncio +async def test_resume_model_deployment_monitoring_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.resume_model_deployment_monitoring_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.resume_model_deployment_monitoring_job + ] = mock_object + + request = {} + await client.resume_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.resume_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_resume_model_deployment_monitoring_job_async( transport: str = "grpc_asyncio", @@ -13692,6 +16925,44 @@ def get_message_fields(field): assert response.state == job_state.JobState.JOB_STATE_QUEUED +def test_create_custom_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_custom_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_custom_job + ] = mock_rpc + + request = {} + client.create_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_custom_job_rest_required_fields( request_type=job_service.CreateCustomJobRequest, ): @@ -13973,6 +17244,42 @@ def test_get_custom_job_rest(request_type): assert response.state == job_state.JobState.JOB_STATE_QUEUED +def test_get_custom_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_custom_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_custom_job] = mock_rpc + + request = {} + client.get_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_custom_job_rest_required_fields( request_type=job_service.GetCustomJobRequest, ): @@ -14239,6 +17546,44 @@ def test_list_custom_jobs_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_custom_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_custom_jobs in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_custom_jobs + ] = mock_rpc + + request = {} + client.list_custom_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_custom_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_custom_jobs_rest_required_fields( request_type=job_service.ListCustomJobsRequest, ): @@ -14576,6 +17921,48 @@ def test_delete_custom_job_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_custom_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_custom_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_custom_job + ] = mock_rpc + + request = {} + client.delete_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_custom_job_rest_required_fields( request_type=job_service.DeleteCustomJobRequest, ): @@ -14836,6 +18223,44 @@ def test_cancel_custom_job_rest(request_type): assert response is None +def test_cancel_custom_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.cancel_custom_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_custom_job + ] = mock_rpc + + request = {} + client.cancel_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_cancel_custom_job_rest_required_fields( request_type=job_service.CancelCustomJobRequest, ): @@ -15227,6 +18652,47 @@ def get_message_fields(field): assert response.specialist_pools == ["specialist_pools_value"] +def test_create_data_labeling_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_data_labeling_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_data_labeling_job + ] = mock_rpc + + request = {} + client.create_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_data_labeling_job_rest_required_fields( request_type=job_service.CreateDataLabelingJobRequest, ): @@ -15510,17 +18976,58 @@ def test_get_data_labeling_job_rest(request_type): req.return_value = response_value response = client.get_data_labeling_job(request) - # Establish that the response is the type that we expect. - assert isinstance(response, data_labeling_job.DataLabelingJob) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.datasets == ["datasets_value"] - assert response.labeler_count == 1375 - assert response.instruction_uri == "instruction_uri_value" - assert response.inputs_schema_uri == "inputs_schema_uri_value" - assert response.state == job_state.JobState.JOB_STATE_QUEUED - assert response.labeling_progress == 1810 - assert response.specialist_pools == ["specialist_pools_value"] + # Establish that the response is the type that we expect. 
+ assert isinstance(response, data_labeling_job.DataLabelingJob) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.datasets == ["datasets_value"] + assert response.labeler_count == 1375 + assert response.instruction_uri == "instruction_uri_value" + assert response.inputs_schema_uri == "inputs_schema_uri_value" + assert response.state == job_state.JobState.JOB_STATE_QUEUED + assert response.labeling_progress == 1810 + assert response.specialist_pools == ["specialist_pools_value"] + + +def test_get_data_labeling_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_data_labeling_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_data_labeling_job + ] = mock_rpc + + request = {} + client.get_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 def test_get_data_labeling_job_rest_required_fields( @@ -15794,6 +19301,47 @@ def test_list_data_labeling_jobs_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_data_labeling_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_data_labeling_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_data_labeling_jobs + ] = mock_rpc + + request = {} + client.list_data_labeling_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_data_labeling_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_data_labeling_jobs_rest_required_fields( request_type=job_service.ListDataLabelingJobsRequest, ): @@ -16136,6 +19684,51 @@ def test_delete_data_labeling_job_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_data_labeling_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_data_labeling_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_data_labeling_job + ] = mock_rpc + + request = {} + client.delete_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_data_labeling_job_rest_required_fields( request_type=job_service.DeleteDataLabelingJobRequest, ): @@ -16401,6 +19994,47 @@ def test_cancel_data_labeling_job_rest(request_type): assert response is None +def test_cancel_data_labeling_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.cancel_data_labeling_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_data_labeling_job + ] = mock_rpc + + request = {} + client.cancel_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_cancel_data_labeling_job_rest_required_fields( request_type=job_service.CancelDataLabelingJobRequest, ): @@ -16929,6 +20563,47 @@ def get_message_fields(field): assert response.state == job_state.JobState.JOB_STATE_QUEUED +def test_create_hyperparameter_tuning_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_hyperparameter_tuning_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_hyperparameter_tuning_job + ] = mock_rpc + + request = {} + client.create_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_hyperparameter_tuning_job_rest_required_fields( request_type=job_service.CreateHyperparameterTuningJobRequest, ): @@ -17234,6 +20909,47 @@ def test_get_hyperparameter_tuning_job_rest(request_type): assert response.state == job_state.JobState.JOB_STATE_QUEUED +def test_get_hyperparameter_tuning_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_hyperparameter_tuning_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_hyperparameter_tuning_job + ] = mock_rpc + + request = {} + client.get_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_hyperparameter_tuning_job_rest_required_fields( request_type=job_service.GetHyperparameterTuningJobRequest, ): @@ -17513,6 +21229,47 @@ def test_list_hyperparameter_tuning_jobs_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_hyperparameter_tuning_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_hyperparameter_tuning_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_hyperparameter_tuning_jobs + ] = mock_rpc + + request = {} + client.list_hyperparameter_tuning_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_hyperparameter_tuning_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_hyperparameter_tuning_jobs_rest_required_fields( request_type=job_service.ListHyperparameterTuningJobsRequest, ): @@ -17866,6 +21623,51 @@ def test_delete_hyperparameter_tuning_job_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_hyperparameter_tuning_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_hyperparameter_tuning_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_hyperparameter_tuning_job + ] = mock_rpc + + request = {} + client.delete_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_hyperparameter_tuning_job_rest_required_fields( request_type=job_service.DeleteHyperparameterTuningJobRequest, ): @@ -18134,6 +21936,47 @@ def test_cancel_hyperparameter_tuning_job_rest(request_type): assert response is None +def test_cancel_hyperparameter_tuning_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.cancel_hyperparameter_tuning_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_hyperparameter_tuning_job + ] = mock_rpc + + request = {} + client.cancel_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_cancel_hyperparameter_tuning_job_rest_required_fields( request_type=job_service.CancelHyperparameterTuningJobRequest, ): @@ -18593,6 +22436,42 @@ def get_message_fields(field): assert response.enable_restricted_image_training is True +def test_create_nas_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_nas_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_nas_job] = mock_rpc + + request = {} + client.create_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_nas_job_rest_required_fields( request_type=job_service.CreateNasJobRequest, ): @@ -18874,6 +22753,42 @@ def test_get_nas_job_rest(request_type): assert response.enable_restricted_image_training is True +def test_get_nas_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_nas_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_nas_job] = mock_rpc + + request = {} + client.get_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_nas_job_rest_required_fields(request_type=job_service.GetNasJobRequest): transport_class = transports.JobServiceRestTransport @@ -19134,6 +23049,42 @@ def test_list_nas_jobs_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_nas_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_nas_jobs in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_nas_jobs] = mock_rpc + + request = {} + client.list_nas_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_nas_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_nas_jobs_rest_required_fields( request_type=job_service.ListNasJobsRequest, ): @@ -19467,6 +23418,46 @@ def test_delete_nas_job_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_nas_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_nas_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_nas_job] = mock_rpc + + request = {} + client.delete_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_nas_job_rest_required_fields( request_type=job_service.DeleteNasJobRequest, ): @@ -19725,6 +23716,42 @@ def test_cancel_nas_job_rest(request_type): assert response is None +def test_cancel_nas_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.cancel_nas_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.cancel_nas_job] = mock_rpc + + request = {} + client.cancel_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_cancel_nas_job_rest_required_fields( request_type=job_service.CancelNasJobRequest, ): @@ -19984,6 +24011,46 @@ def test_get_nas_trial_detail_rest(request_type): assert response.parameters == "parameters_value" +def test_get_nas_trial_detail_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_nas_trial_detail in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_nas_trial_detail + ] = mock_rpc + + request = {} + client.get_nas_trial_detail(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_nas_trial_detail(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_nas_trial_detail_rest_required_fields( request_type=job_service.GetNasTrialDetailRequest, ): @@ -20255,6 +24322,47 @@ def test_list_nas_trial_details_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_nas_trial_details_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_nas_trial_details + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_nas_trial_details + ] = mock_rpc + + request = {} + client.list_nas_trial_details(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_nas_trial_details(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_nas_trial_details_rest_required_fields( request_type=job_service.ListNasTrialDetailsRequest, ): @@ -20818,6 +24926,47 @@ def get_message_fields(field): assert response.disable_container_logging is True +def test_create_batch_prediction_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_batch_prediction_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_batch_prediction_job + ] = mock_rpc + + request = {} + client.create_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_batch_prediction_job_rest_required_fields( request_type=job_service.CreateBatchPredictionJobRequest, ): @@ -21116,6 +25265,47 @@ def test_get_batch_prediction_job_rest(request_type): assert response.disable_container_logging is True +def test_get_batch_prediction_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_batch_prediction_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_batch_prediction_job + ] = mock_rpc + + request = {} + client.get_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_batch_prediction_job_rest_required_fields( request_type=job_service.GetBatchPredictionJobRequest, ): @@ -21387,6 +25577,47 @@ def test_list_batch_prediction_jobs_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_batch_prediction_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_batch_prediction_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_batch_prediction_jobs + ] = mock_rpc + + request = {} + client.list_batch_prediction_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_batch_prediction_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_batch_prediction_jobs_rest_required_fields( request_type=job_service.ListBatchPredictionJobsRequest, ): @@ -21729,6 +25960,51 @@ def test_delete_batch_prediction_job_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_batch_prediction_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_batch_prediction_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_batch_prediction_job + ] = mock_rpc + + request = {} + client.delete_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_batch_prediction_job_rest_required_fields( request_type=job_service.DeleteBatchPredictionJobRequest, ): @@ -21994,6 +26270,47 @@ def test_cancel_batch_prediction_job_rest(request_type): assert response is None +def test_cancel_batch_prediction_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.cancel_batch_prediction_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_batch_prediction_job + ] = mock_rpc + + request = {} + client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_cancel_batch_prediction_job_rest_required_fields( request_type=job_service.CancelBatchPredictionJobRequest, ): @@ -22448,6 +26765,47 @@ def get_message_fields(field): assert response.enable_monitoring_pipeline_logs is True +def test_create_model_deployment_monitoring_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_model_deployment_monitoring_job + ] = mock_rpc + + request = {} + client.create_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_model_deployment_monitoring_job_rest_required_fields( request_type=job_service.CreateModelDeploymentMonitoringJobRequest, ): @@ -22765,6 +27123,47 @@ def test_search_model_deployment_monitoring_stats_anomalies_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_search_model_deployment_monitoring_stats_anomalies_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.search_model_deployment_monitoring_stats_anomalies + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.search_model_deployment_monitoring_stats_anomalies + ] = mock_rpc + + request = {} + client.search_model_deployment_monitoring_stats_anomalies(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_model_deployment_monitoring_stats_anomalies(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_search_model_deployment_monitoring_stats_anomalies_rest_required_fields( request_type=job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest, ): @@ -23184,6 +27583,47 @@ def test_get_model_deployment_monitoring_job_rest(request_type): assert response.enable_monitoring_pipeline_logs is True +def test_get_model_deployment_monitoring_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_model_deployment_monitoring_job + ] = mock_rpc + + request = {} + client.get_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_model_deployment_monitoring_job_rest_required_fields( request_type=job_service.GetModelDeploymentMonitoringJobRequest, ): @@ -23472,6 +27912,47 @@ def test_list_model_deployment_monitoring_jobs_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_model_deployment_monitoring_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_model_deployment_monitoring_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_model_deployment_monitoring_jobs + ] = mock_rpc + + request = {} + client.list_model_deployment_monitoring_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_model_deployment_monitoring_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_model_deployment_monitoring_jobs_rest_required_fields( request_type=job_service.ListModelDeploymentMonitoringJobsRequest, ): @@ -24009,6 +28490,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_model_deployment_monitoring_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_model_deployment_monitoring_job + ] = mock_rpc + + request = {} + client.update_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_model_deployment_monitoring_job_rest_required_fields( request_type=job_service.UpdateModelDeploymentMonitoringJobRequest, ): @@ -24301,6 +28827,51 @@ def test_delete_model_deployment_monitoring_job_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_model_deployment_monitoring_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_model_deployment_monitoring_job + ] = mock_rpc + + request = {} + client.delete_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_model_deployment_monitoring_job_rest_required_fields( request_type=job_service.DeleteModelDeploymentMonitoringJobRequest, ): @@ -24577,6 +29148,47 @@ def test_pause_model_deployment_monitoring_job_rest(request_type): assert response is None +def test_pause_model_deployment_monitoring_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.pause_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.pause_model_deployment_monitoring_job + ] = mock_rpc + + request = {} + client.pause_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.pause_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_pause_model_deployment_monitoring_job_rest_required_fields( request_type=job_service.PauseModelDeploymentMonitoringJobRequest, ): @@ -24843,6 +29455,47 @@ def test_resume_model_deployment_monitoring_job_rest(request_type): assert response is None +def test_resume_model_deployment_monitoring_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.resume_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.resume_model_deployment_monitoring_job + ] = mock_rpc + + request = {} + client.resume_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.resume_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_resume_model_deployment_monitoring_job_rest_required_fields( request_type=job_service.ResumeModelDeploymentMonitoringJobRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_llm_utility_service.py b/tests/unit/gapic/aiplatform_v1/test_llm_utility_service.py index 1a92652d3b..8549e7a55f 100644 --- a/tests/unit/gapic/aiplatform_v1/test_llm_utility_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_llm_utility_service.py @@ -1218,6 +1218,9 @@ def test_count_tokens_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.count_tokens), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.count_tokens() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1242,6 +1245,9 @@ def test_count_tokens_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.count_tokens), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.count_tokens(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1251,6 +1257,41 @@ def test_count_tokens_non_empty_request_with_auto_populated_field(): ) +def test_count_tokens_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = LlmUtilityServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.count_tokens in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.count_tokens] = mock_rpc + request = {} + client.count_tokens(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.count_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_count_tokens_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1275,6 +1316,52 @@ async def test_count_tokens_empty_call_async(): assert args[0] == prediction_service.CountTokensRequest() +@pytest.mark.asyncio +async def test_count_tokens_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = LlmUtilityServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.count_tokens + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.count_tokens + ] = mock_object + + request = {} + await client.count_tokens(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.count_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_count_tokens_async( transport: str = "grpc_asyncio", request_type=prediction_service.CountTokensRequest @@ -1512,6 +1599,9 @@ def test_compute_tokens_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.compute_tokens), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.compute_tokens() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1535,6 +1625,9 @@ def test_compute_tokens_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.compute_tokens), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.compute_tokens(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1543,6 +1636,41 @@ def test_compute_tokens_non_empty_request_with_auto_populated_field(): ) +def test_compute_tokens_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = LlmUtilityServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.compute_tokens in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.compute_tokens] = mock_rpc + request = {} + client.compute_tokens(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.compute_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_compute_tokens_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1564,6 +1692,52 @@ async def test_compute_tokens_empty_call_async(): assert args[0] == llm_utility_service.ComputeTokensRequest() +@pytest.mark.asyncio +async def test_compute_tokens_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = LlmUtilityServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.compute_tokens + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.compute_tokens + ] = mock_object + + request = {} + await client.compute_tokens(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.compute_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_compute_tokens_async( transport: str = "grpc_asyncio", @@ -1796,6 +1970,42 @@ def test_count_tokens_rest(request_type): assert response.total_billable_characters == 2617 +def test_count_tokens_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = LlmUtilityServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.count_tokens in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.count_tokens] = mock_rpc + + request = {} + client.count_tokens(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.count_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_count_tokens_rest_required_fields( request_type=prediction_service.CountTokensRequest, ): @@ -2079,6 +2289,42 @@ def test_compute_tokens_rest(request_type): assert isinstance(response, llm_utility_service.ComputeTokensResponse) +def test_compute_tokens_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = LlmUtilityServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.compute_tokens in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.compute_tokens] = mock_rpc + + request = {} + client.compute_tokens(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.compute_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_compute_tokens_rest_required_fields( request_type=llm_utility_service.ComputeTokensRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_match_service.py b/tests/unit/gapic/aiplatform_v1/test_match_service.py index bbd93c378d..3a9a4ac206 100644 --- a/tests/unit/gapic/aiplatform_v1/test_match_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_match_service.py @@ -1144,6 +1144,9 @@ def test_find_neighbors_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.find_neighbors), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.find_neighbors() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1168,6 +1171,9 @@ def test_find_neighbors_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.find_neighbors), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.find_neighbors(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1177,6 +1183,41 @@ def test_find_neighbors_non_empty_request_with_auto_populated_field(): ) +def test_find_neighbors_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MatchServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.find_neighbors in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.find_neighbors] = mock_rpc + request = {} + client.find_neighbors(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.find_neighbors(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_find_neighbors_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1198,6 +1239,52 @@ async def test_find_neighbors_empty_call_async(): assert args[0] == match_service.FindNeighborsRequest() +@pytest.mark.asyncio +async def test_find_neighbors_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MatchServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.find_neighbors + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.find_neighbors + ] = mock_object + + request = {} + await client.find_neighbors(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.find_neighbors(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_find_neighbors_async( transport: str = "grpc_asyncio", request_type=match_service.FindNeighborsRequest @@ -1342,6 +1429,9 @@ def test_read_index_datapoints_empty_call(): with mock.patch.object( type(client.transport.read_index_datapoints), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.read_index_datapoints() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1368,6 +1458,9 @@ def test_read_index_datapoints_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.read_index_datapoints), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.read_index_datapoints(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1377,6 +1470,46 @@ def test_read_index_datapoints_non_empty_request_with_auto_populated_field(): ) +def test_read_index_datapoints_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MatchServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_index_datapoints + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_index_datapoints + ] = mock_rpc + request = {} + client.read_index_datapoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_index_datapoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_read_index_datapoints_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1400,6 +1533,52 @@ async def test_read_index_datapoints_empty_call_async(): assert args[0] == match_service.ReadIndexDatapointsRequest() +@pytest.mark.asyncio +async def test_read_index_datapoints_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MatchServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.read_index_datapoints + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.read_index_datapoints + ] = mock_object + + request = {} + await client.read_index_datapoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.read_index_datapoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_read_index_datapoints_async( transport: str = "grpc_asyncio", @@ -1543,6 +1722,42 @@ def test_find_neighbors_rest(request_type): assert isinstance(response, match_service.FindNeighborsResponse) +def test_find_neighbors_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MatchServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.find_neighbors in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.find_neighbors] = mock_rpc + + request = {} + client.find_neighbors(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.find_neighbors(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_find_neighbors_rest_required_fields( request_type=match_service.FindNeighborsRequest, ): @@ -1755,6 +1970,47 @@ def test_read_index_datapoints_rest(request_type): assert isinstance(response, match_service.ReadIndexDatapointsResponse) +def test_read_index_datapoints_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MatchServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_index_datapoints + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_index_datapoints + ] = mock_rpc + + request = {} + client.read_index_datapoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_index_datapoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_read_index_datapoints_rest_required_fields( request_type=match_service.ReadIndexDatapointsRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_metadata_service.py b/tests/unit/gapic/aiplatform_v1/test_metadata_service.py index f2ec6f300c..efff9e1a3b 100644 --- a/tests/unit/gapic/aiplatform_v1/test_metadata_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_metadata_service.py @@ -1219,6 +1219,9 @@ def test_create_metadata_store_empty_call(): with mock.patch.object( type(client.transport.create_metadata_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_metadata_store() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1245,6 +1248,9 @@ def test_create_metadata_store_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_metadata_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_metadata_store(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1254,6 +1260,50 @@ def test_create_metadata_store_non_empty_request_with_auto_populated_field(): ) +def test_create_metadata_store_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_metadata_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_metadata_store + ] = mock_rpc + request = {} + client.create_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_metadata_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_metadata_store_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1277,6 +1327,56 @@ async def test_create_metadata_store_empty_call_async(): assert args[0] == metadata_service.CreateMetadataStoreRequest() +@pytest.mark.asyncio +async def test_create_metadata_store_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_metadata_store + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_metadata_store + ] = mock_object + + request = {} + await client.create_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_metadata_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_metadata_store_async( transport: str = "grpc_asyncio", @@ -1539,6 +1639,9 @@ def test_get_metadata_store_empty_call(): with mock.patch.object( type(client.transport.get_metadata_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_metadata_store() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1564,6 +1667,9 @@ def test_get_metadata_store_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_metadata_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_metadata_store(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1572,6 +1678,45 @@ def test_get_metadata_store_non_empty_request_with_auto_populated_field(): ) +def test_get_metadata_store_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_metadata_store in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_metadata_store + ] = mock_rpc + request = {} + client.get_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_metadata_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_metadata_store_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1598,6 +1743,52 @@ async def test_get_metadata_store_empty_call_async(): assert args[0] == metadata_service.GetMetadataStoreRequest() +@pytest.mark.asyncio +async def test_get_metadata_store_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_metadata_store + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_metadata_store + ] = mock_object + + request = {} + await client.get_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_metadata_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_metadata_store_async( transport: str = "grpc_asyncio", @@ -1843,6 +2034,9 @@ def test_list_metadata_stores_empty_call(): with mock.patch.object( type(client.transport.list_metadata_stores), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_metadata_stores() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1869,6 +2063,9 @@ def test_list_metadata_stores_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_metadata_stores), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_metadata_stores(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1878,6 +2075,45 @@ def test_list_metadata_stores_non_empty_request_with_auto_populated_field(): ) +def test_list_metadata_stores_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_metadata_stores in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_metadata_stores + ] = mock_rpc + request = {} + client.list_metadata_stores(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_metadata_stores(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_metadata_stores_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1903,6 +2139,52 @@ async def test_list_metadata_stores_empty_call_async(): assert args[0] == metadata_service.ListMetadataStoresRequest() +@pytest.mark.asyncio +async def test_list_metadata_stores_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_metadata_stores + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_metadata_stores + ] = mock_object + + request = {} + await client.list_metadata_stores(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_metadata_stores(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_metadata_stores_async( transport: str = "grpc_asyncio", @@ -2341,6 +2623,9 @@ def test_delete_metadata_store_empty_call(): with mock.patch.object( type(client.transport.delete_metadata_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_metadata_store() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2366,6 +2651,9 @@ def test_delete_metadata_store_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_metadata_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_metadata_store(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2374,6 +2662,50 @@ def test_delete_metadata_store_non_empty_request_with_auto_populated_field(): ) +def test_delete_metadata_store_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_metadata_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_metadata_store + ] = mock_rpc + request = {} + client.delete_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_metadata_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_metadata_store_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2397,6 +2729,56 @@ async def test_delete_metadata_store_empty_call_async(): assert args[0] == metadata_service.DeleteMetadataStoreRequest() +@pytest.mark.asyncio +async def test_delete_metadata_store_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_metadata_store + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_metadata_store + ] = mock_object + + request = {} + await client.delete_metadata_store(request) + + # Establish that the underlying gRPC stub 
method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_metadata_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_metadata_store_async( transport: str = "grpc_asyncio", @@ -2647,6 +3029,9 @@ def test_create_artifact_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_artifact), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_artifact() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2671,6 +3056,9 @@ def test_create_artifact_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_artifact), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_artifact(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2680,6 +3068,41 @@ def test_create_artifact_non_empty_request_with_auto_populated_field(): ) +def test_create_artifact_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_artifact in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_artifact] = mock_rpc + request = {} + client.create_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_artifact_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2710,6 +3133,52 @@ async def test_create_artifact_empty_call_async(): assert args[0] == metadata_service.CreateArtifactRequest() +@pytest.mark.asyncio +async def test_create_artifact_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_artifact + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_artifact + ] = mock_object + + request = {} + await client.create_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_artifact_async( transport: str = "grpc_asyncio", request_type=metadata_service.CreateArtifactRequest @@ -2986,6 +3455,9 @@ def test_get_artifact_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_artifact), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_artifact() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3009,6 +3481,9 @@ def test_get_artifact_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_artifact), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_artifact(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3017,6 +3492,41 @@ def test_get_artifact_non_empty_request_with_auto_populated_field(): ) +def test_get_artifact_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_artifact in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_artifact] = mock_rpc + request = {} + client.get_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_artifact_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3047,6 +3557,52 @@ async def test_get_artifact_empty_call_async(): assert args[0] == metadata_service.GetArtifactRequest() +@pytest.mark.asyncio +async def test_get_artifact_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_artifact + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_artifact + ] = mock_object + + request = {} + await client.get_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_artifact_async( transport: str = "grpc_asyncio", request_type=metadata_service.GetArtifactRequest @@ -3285,6 +3841,9 @@ def test_list_artifacts_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_artifacts() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3311,6 +3870,9 @@ def test_list_artifacts_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_artifacts(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3322,18 +3884,53 @@ def test_list_artifacts_non_empty_request_with_auto_populated_field(): ) -@pytest.mark.asyncio -async def test_list_artifacts_empty_call_async(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. 
- client = MetadataServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc_asyncio", - ) +def test_list_artifacts_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: - # Designate an appropriate return value for the call. + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_artifacts in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_artifacts] = mock_rpc + request = {} + client.list_artifacts(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_artifacts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_artifacts_empty_call_async(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: + # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( metadata_service.ListArtifactsResponse( next_page_token="next_page_token_value", @@ -3345,6 +3942,52 @@ async def test_list_artifacts_empty_call_async(): assert args[0] == metadata_service.ListArtifactsRequest() +@pytest.mark.asyncio +async def test_list_artifacts_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_artifacts + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_artifacts + ] = mock_object + + request = {} + await client.list_artifacts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_artifacts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_artifacts_async( transport: str = "grpc_asyncio", request_type=metadata_service.ListArtifactsRequest @@ -3777,6 +4420,9 @@ def test_update_artifact_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_artifact), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_artifact() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3798,12 +4444,50 @@ def test_update_artifact_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_artifact), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_artifact(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.UpdateArtifactRequest() +def test_update_artifact_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_artifact in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_artifact] = mock_rpc + request = {} + client.update_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_artifact_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3834,6 +4518,52 @@ async def test_update_artifact_empty_call_async(): assert args[0] == metadata_service.UpdateArtifactRequest() +@pytest.mark.asyncio +async def test_update_artifact_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_artifact + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_artifact + ] = mock_object + + request = {} + await client.update_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_artifact_async( transport: str = "grpc_asyncio", request_type=metadata_service.UpdateArtifactRequest @@ -4083,6 +4813,9 @@ def test_delete_artifact_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_artifact), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_artifact() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4107,6 +4840,9 @@ def test_delete_artifact_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_artifact), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_artifact(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4116,6 +4852,45 @@ def test_delete_artifact_non_empty_request_with_auto_populated_field(): ) +def test_delete_artifact_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_artifact in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_artifact] = mock_rpc + request = {} + client.delete_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_artifact_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4137,6 +4912,56 @@ async def test_delete_artifact_empty_call_async(): assert args[0] == metadata_service.DeleteArtifactRequest() +@pytest.mark.asyncio +async def test_delete_artifact_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_artifact + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_artifact + ] = mock_object + + request = {} + await client.delete_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_artifact_async( transport: str = "grpc_asyncio", request_type=metadata_service.DeleteArtifactRequest @@ -4359,6 +5184,9 @@ def test_purge_artifacts_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.purge_artifacts), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.purge_artifacts() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4383,6 +5211,9 @@ def test_purge_artifacts_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.purge_artifacts), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.purge_artifacts(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4392,6 +5223,45 @@ def test_purge_artifacts_non_empty_request_with_auto_populated_field(): ) +def test_purge_artifacts_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.purge_artifacts in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.purge_artifacts] = mock_rpc + request = {} + client.purge_artifacts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.purge_artifacts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_purge_artifacts_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4413,6 +5283,56 @@ async def test_purge_artifacts_empty_call_async(): assert args[0] == metadata_service.PurgeArtifactsRequest() +@pytest.mark.asyncio +async def test_purge_artifacts_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.purge_artifacts + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.purge_artifacts + ] = mock_object + + request = {} + await client.purge_artifacts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.purge_artifacts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_purge_artifacts_async( transport: str = "grpc_asyncio", request_type=metadata_service.PurgeArtifactsRequest @@ -4650,6 +5570,9 @@ def test_create_context_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_context), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_context() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4674,6 +5597,9 @@ def test_create_context_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_context), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_context(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4683,6 +5609,41 @@ def test_create_context_non_empty_request_with_auto_populated_field(): ) +def test_create_context_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_context in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_context] = mock_rpc + request = {} + client.create_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_context_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4712,6 +5673,52 @@ async def test_create_context_empty_call_async(): assert args[0] == metadata_service.CreateContextRequest() +@pytest.mark.asyncio +async def test_create_context_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_context + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_context + ] = mock_object + + request = {} + await client.create_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_context_async( transport: str = "grpc_asyncio", request_type=metadata_service.CreateContextRequest @@ -4980,6 +5987,9 @@ def test_get_context_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_context), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_context() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5003,6 +6013,9 @@ def test_get_context_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_context), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_context(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5011,6 +6024,41 @@ def test_get_context_non_empty_request_with_auto_populated_field(): ) +def test_get_context_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_context in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_context] = mock_rpc + request = {} + client.get_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_context_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5040,6 +6088,52 @@ async def test_get_context_empty_call_async(): assert args[0] == metadata_service.GetContextRequest() +@pytest.mark.asyncio +async def test_get_context_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_context + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_context + ] = mock_object + + request = {} + await client.get_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_context_async( transport: str = "grpc_asyncio", request_type=metadata_service.GetContextRequest @@ -5276,6 +6370,9 @@ def test_list_contexts_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_contexts() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5302,6 +6399,9 @@ def test_list_contexts_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_contexts(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5313,6 +6413,41 @@ def test_list_contexts_non_empty_request_with_auto_populated_field(): ) +def test_list_contexts_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_contexts in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_contexts] = mock_rpc + request = {} + client.list_contexts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_contexts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_contexts_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5337,7 +6472,53 @@ async def test_list_contexts_empty_call_async(): @pytest.mark.asyncio -async def test_list_contexts_async( +async def test_list_contexts_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_contexts + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_contexts + ] = mock_object + + request = {} + await client.list_contexts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_contexts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_contexts_async( transport: str = "grpc_asyncio", request_type=metadata_service.ListContextsRequest ): client = MetadataServiceAsyncClient( @@ -5766,6 +6947,9 @@ def test_update_context_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_context), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_context() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5787,12 +6971,50 @@ def test_update_context_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_context), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_context(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.UpdateContextRequest() +def test_update_context_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_context in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_context] = mock_rpc + request = {} + client.update_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_context_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5822,6 +7044,52 @@ async def test_update_context_empty_call_async(): assert args[0] == metadata_service.UpdateContextRequest() +@pytest.mark.asyncio +async def test_update_context_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_context + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_context + ] = mock_object + + request = {} + await client.update_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_context_async( transport: str = "grpc_asyncio", request_type=metadata_service.UpdateContextRequest @@ -6065,6 +7333,9 @@ def test_delete_context_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_context), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_context() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6089,6 +7360,9 @@ def test_delete_context_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_context), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_context(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6098,6 +7372,45 @@ def test_delete_context_non_empty_request_with_auto_populated_field(): ) +def test_delete_context_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_context in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_context] = mock_rpc + request = {} + client.delete_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_context_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6119,6 +7432,56 @@ async def test_delete_context_empty_call_async(): assert args[0] == metadata_service.DeleteContextRequest() +@pytest.mark.asyncio +async def test_delete_context_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_context + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_context + ] = mock_object + + request = {} + await client.delete_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_context_async( transport: str = "grpc_asyncio", request_type=metadata_service.DeleteContextRequest @@ -6341,6 +7704,9 @@ def test_purge_contexts_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.purge_contexts), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.purge_contexts() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6365,6 +7731,9 @@ def test_purge_contexts_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.purge_contexts), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.purge_contexts(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6374,6 +7743,45 @@ def test_purge_contexts_non_empty_request_with_auto_populated_field(): ) +def test_purge_contexts_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.purge_contexts in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.purge_contexts] = mock_rpc + request = {} + client.purge_contexts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.purge_contexts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_purge_contexts_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6395,6 +7803,56 @@ async def test_purge_contexts_empty_call_async(): assert args[0] == metadata_service.PurgeContextsRequest() +@pytest.mark.asyncio +async def test_purge_contexts_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.purge_contexts + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.purge_contexts + ] = mock_object + + request = {} + await client.purge_contexts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.purge_contexts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_purge_contexts_async( transport: str = "grpc_asyncio", request_type=metadata_service.PurgeContextsRequest @@ -6623,6 +8081,9 @@ def test_add_context_artifacts_and_executions_empty_call(): with mock.patch.object( type(client.transport.add_context_artifacts_and_executions), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.add_context_artifacts_and_executions() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6648,6 +8109,9 @@ def test_add_context_artifacts_and_executions_non_empty_request_with_auto_popula with mock.patch.object( type(client.transport.add_context_artifacts_and_executions), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.add_context_artifacts_and_executions(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6656,6 +8120,46 @@ def test_add_context_artifacts_and_executions_non_empty_request_with_auto_popula ) +def test_add_context_artifacts_and_executions_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.add_context_artifacts_and_executions + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.add_context_artifacts_and_executions + ] = mock_rpc + request = {} + client.add_context_artifacts_and_executions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.add_context_artifacts_and_executions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_add_context_artifacts_and_executions_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6679,6 +8183,52 @@ async def test_add_context_artifacts_and_executions_empty_call_async(): assert args[0] == metadata_service.AddContextArtifactsAndExecutionsRequest() +@pytest.mark.asyncio +async def test_add_context_artifacts_and_executions_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.add_context_artifacts_and_executions + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.add_context_artifacts_and_executions + ] = mock_object + + request = {} + await client.add_context_artifacts_and_executions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.add_context_artifacts_and_executions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_add_context_artifacts_and_executions_async( transport: str = "grpc_asyncio", @@ -6938,6 +8488,9 @@ def test_add_context_children_empty_call(): with mock.patch.object( type(client.transport.add_context_children), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.add_context_children() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6963,6 +8516,9 @@ def test_add_context_children_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.add_context_children), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.add_context_children(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6971,6 +8527,45 @@ def test_add_context_children_non_empty_request_with_auto_populated_field(): ) +def test_add_context_children_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.add_context_children in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.add_context_children + ] = mock_rpc + request = {} + client.add_context_children(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.add_context_children(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_add_context_children_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6994,6 +8589,52 @@ async def test_add_context_children_empty_call_async(): assert args[0] == metadata_service.AddContextChildrenRequest() +@pytest.mark.asyncio +async def test_add_context_children_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.add_context_children + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.add_context_children + ] = mock_object + + request = {} + await client.add_context_children(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.add_context_children(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_add_context_children_async( transport: str = "grpc_asyncio", @@ -7241,6 +8882,9 @@ def test_remove_context_children_empty_call(): with mock.patch.object( type(client.transport.remove_context_children), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.remove_context_children() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7266,6 +8910,9 @@ def test_remove_context_children_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.remove_context_children), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.remove_context_children(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7274,6 +8921,46 @@ def test_remove_context_children_non_empty_request_with_auto_populated_field(): ) +def test_remove_context_children_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.remove_context_children + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.remove_context_children + ] = mock_rpc + request = {} + client.remove_context_children(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.remove_context_children(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_remove_context_children_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7298,24 +8985,70 @@ async def test_remove_context_children_empty_call_async(): @pytest.mark.asyncio -async def test_remove_context_children_async( +async def test_remove_context_children_async_use_cached_wrapped_rpc( transport: str = "grpc_asyncio", - request_type=metadata_service.RemoveContextChildrenRequest, ): - client = MetadataServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.remove_context_children), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
+ # Ensure method has been cached + assert ( + client._client._transport.remove_context_children + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.remove_context_children + ] = mock_object + + request = {} + await client.remove_context_children(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.remove_context_children(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_remove_context_children_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.RemoveContextChildrenRequest, +): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.remove_context_children), "__call__" + ) as call: + # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( metadata_service.RemoveContextChildrenResponse() ) @@ -7544,6 +9277,9 @@ def test_query_context_lineage_subgraph_empty_call(): with mock.patch.object( type(client.transport.query_context_lineage_subgraph), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.query_context_lineage_subgraph() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7569,6 +9305,9 @@ def test_query_context_lineage_subgraph_non_empty_request_with_auto_populated_fi with mock.patch.object( type(client.transport.query_context_lineage_subgraph), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.query_context_lineage_subgraph(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7577,6 +9316,46 @@ def test_query_context_lineage_subgraph_non_empty_request_with_auto_populated_fi ) +def test_query_context_lineage_subgraph_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.query_context_lineage_subgraph + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.query_context_lineage_subgraph + ] = mock_rpc + request = {} + client.query_context_lineage_subgraph(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.query_context_lineage_subgraph(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_query_context_lineage_subgraph_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7600,6 +9379,52 @@ async def test_query_context_lineage_subgraph_empty_call_async(): assert args[0] == metadata_service.QueryContextLineageSubgraphRequest() +@pytest.mark.asyncio +async def test_query_context_lineage_subgraph_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.query_context_lineage_subgraph + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.query_context_lineage_subgraph + ] = mock_object + + request = {} + await client.query_context_lineage_subgraph(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.query_context_lineage_subgraph(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_query_context_lineage_subgraph_async( transport: str = "grpc_asyncio", @@ -7848,6 +9673,9 @@ def test_create_execution_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_execution), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_execution() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7872,6 +9700,9 @@ def test_create_execution_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_execution), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_execution(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7881,6 +9712,43 @@ def test_create_execution_non_empty_request_with_auto_populated_field(): ) +def test_create_execution_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_execution in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_execution + ] = mock_rpc + request = {} + client.create_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_execution_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7910,6 +9778,52 @@ async def test_create_execution_empty_call_async(): assert args[0] == metadata_service.CreateExecutionRequest() +@pytest.mark.asyncio +async def test_create_execution_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_execution + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_execution + ] = mock_object + + request = {} + await client.create_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_execution_async( transport: str = "grpc_asyncio", @@ -8183,6 +10097,9 @@ def test_get_execution_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_execution), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_execution() call.assert_called() _, args, _ = call.mock_calls[0] @@ -8206,6 +10123,9 @@ def test_get_execution_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_execution), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_execution(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -8214,6 +10134,41 @@ def test_get_execution_non_empty_request_with_auto_populated_field(): ) +def test_get_execution_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_execution in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_execution] = mock_rpc + request = {} + client.get_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_execution_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -8243,6 +10198,52 @@ async def test_get_execution_empty_call_async(): assert args[0] == metadata_service.GetExecutionRequest() +@pytest.mark.asyncio +async def test_get_execution_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_execution + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_execution + ] = mock_object + + request = {} + await client.get_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_execution_async( transport: str = "grpc_asyncio", request_type=metadata_service.GetExecutionRequest @@ -8479,6 +10480,9 @@ def test_list_executions_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_executions), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_executions() call.assert_called() _, args, _ = call.mock_calls[0] @@ -8505,6 +10509,9 @@ def test_list_executions_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_executions), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_executions(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -8516,6 +10523,41 @@ def test_list_executions_non_empty_request_with_auto_populated_field(): ) +def test_list_executions_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_executions in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_executions] = mock_rpc + request = {} + client.list_executions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_executions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_executions_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -8539,6 +10581,52 @@ async def test_list_executions_empty_call_async(): assert args[0] == metadata_service.ListExecutionsRequest() +@pytest.mark.asyncio +async def test_list_executions_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_executions + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_executions + ] = mock_object + + request = {} + await client.list_executions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_executions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_executions_async( transport: str = "grpc_asyncio", request_type=metadata_service.ListExecutionsRequest @@ -8969,6 +11057,9 @@ def test_update_execution_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_execution), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_execution() call.assert_called() _, args, _ = call.mock_calls[0] @@ -8990,12 +11081,52 @@ def test_update_execution_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_execution), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_execution(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.UpdateExecutionRequest() +def test_update_execution_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_execution in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_execution + ] = mock_rpc + request = {} + client.update_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_execution_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -9025,6 +11156,52 @@ async def test_update_execution_empty_call_async(): assert args[0] == metadata_service.UpdateExecutionRequest() +@pytest.mark.asyncio +async def test_update_execution_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_execution + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_execution + ] = mock_object + + request = {} + await client.update_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_execution_async( transport: str = "grpc_asyncio", @@ -9273,6 +11450,9 @@ def test_delete_execution_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_execution), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_execution() call.assert_called() _, args, _ = call.mock_calls[0] @@ -9297,6 +11477,9 @@ def test_delete_execution_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_execution), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_execution(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -9306,6 +11489,47 @@ def test_delete_execution_non_empty_request_with_auto_populated_field(): ) +def test_delete_execution_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_execution in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_execution + ] = mock_rpc + request = {} + client.delete_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_execution_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -9327,6 +11551,56 @@ async def test_delete_execution_empty_call_async(): assert args[0] == metadata_service.DeleteExecutionRequest() +@pytest.mark.asyncio +async def test_delete_execution_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_execution + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_execution + ] = mock_object + + request = {} + await client.delete_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_execution_async( transport: str = "grpc_asyncio", @@ -9550,6 +11824,9 @@ def test_purge_executions_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.purge_executions), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.purge_executions() call.assert_called() _, args, _ = call.mock_calls[0] @@ -9574,6 +11851,9 @@ def test_purge_executions_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.purge_executions), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.purge_executions(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -9583,6 +11863,47 @@ def test_purge_executions_non_empty_request_with_auto_populated_field(): ) +def test_purge_executions_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.purge_executions in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.purge_executions + ] = mock_rpc + request = {} + client.purge_executions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.purge_executions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_purge_executions_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -9604,6 +11925,56 @@ async def test_purge_executions_empty_call_async(): assert args[0] == metadata_service.PurgeExecutionsRequest() +@pytest.mark.asyncio +async def test_purge_executions_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.purge_executions + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.purge_executions + ] = mock_object + + request = {} + await client.purge_executions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.purge_executions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_purge_executions_async( transport: str = "grpc_asyncio", @@ -9831,6 +12202,9 @@ def test_add_execution_events_empty_call(): with mock.patch.object( type(client.transport.add_execution_events), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.add_execution_events() call.assert_called() _, args, _ = call.mock_calls[0] @@ -9856,6 +12230,9 @@ def test_add_execution_events_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.add_execution_events), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.add_execution_events(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -9864,6 +12241,45 @@ def test_add_execution_events_non_empty_request_with_auto_populated_field(): ) +def test_add_execution_events_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.add_execution_events in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.add_execution_events + ] = mock_rpc + request = {} + client.add_execution_events(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.add_execution_events(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_add_execution_events_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -9887,6 +12303,52 @@ async def test_add_execution_events_empty_call_async(): assert args[0] == metadata_service.AddExecutionEventsRequest() +@pytest.mark.asyncio +async def test_add_execution_events_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.add_execution_events + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.add_execution_events + ] = mock_object + + request = {} + await client.add_execution_events(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.add_execution_events(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_add_execution_events_async( transport: str = "grpc_asyncio", @@ -10134,6 +12596,9 @@ def test_query_execution_inputs_and_outputs_empty_call(): with mock.patch.object( type(client.transport.query_execution_inputs_and_outputs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.query_execution_inputs_and_outputs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -10159,6 +12624,9 @@ def test_query_execution_inputs_and_outputs_non_empty_request_with_auto_populate with mock.patch.object( type(client.transport.query_execution_inputs_and_outputs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.query_execution_inputs_and_outputs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -10167,6 +12635,46 @@ def test_query_execution_inputs_and_outputs_non_empty_request_with_auto_populate ) +def test_query_execution_inputs_and_outputs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.query_execution_inputs_and_outputs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.query_execution_inputs_and_outputs + ] = mock_rpc + request = {} + client.query_execution_inputs_and_outputs(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.query_execution_inputs_and_outputs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_query_execution_inputs_and_outputs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -10176,18 +12684,64 @@ async def test_query_execution_inputs_and_outputs_empty_call_async(): transport="grpc_asyncio", ) - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.query_execution_inputs_and_outputs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - lineage_subgraph.LineageSubgraph() - ) - response = await client.query_execution_inputs_and_outputs() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.QueryExecutionInputsAndOutputsRequest() + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_execution_inputs_and_outputs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + lineage_subgraph.LineageSubgraph() + ) + response = await client.query_execution_inputs_and_outputs() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == metadata_service.QueryExecutionInputsAndOutputsRequest() + + +@pytest.mark.asyncio +async def test_query_execution_inputs_and_outputs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.query_execution_inputs_and_outputs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + 
client._client._transport.query_execution_inputs_and_outputs + ] = mock_object + + request = {} + await client.query_execution_inputs_and_outputs(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.query_execution_inputs_and_outputs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 @pytest.mark.asyncio @@ -10441,6 +12995,9 @@ def test_create_metadata_schema_empty_call(): with mock.patch.object( type(client.transport.create_metadata_schema), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_metadata_schema() call.assert_called() _, args, _ = call.mock_calls[0] @@ -10467,6 +13024,9 @@ def test_create_metadata_schema_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_metadata_schema), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_metadata_schema(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -10476,6 +13036,46 @@ def test_create_metadata_schema_non_empty_request_with_auto_populated_field(): ) +def test_create_metadata_schema_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_metadata_schema + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_metadata_schema + ] = mock_rpc + request = {} + client.create_metadata_schema(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_metadata_schema(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_metadata_schema_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -10505,6 +13105,52 @@ async def test_create_metadata_schema_empty_call_async(): assert args[0] == metadata_service.CreateMetadataSchemaRequest() +@pytest.mark.asyncio +async def test_create_metadata_schema_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_metadata_schema + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_metadata_schema + ] = mock_object + + request = {} + await client.create_metadata_schema(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_metadata_schema(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_metadata_schema_async( transport: str = "grpc_asyncio", @@ -10790,6 +13436,9 @@ def test_get_metadata_schema_empty_call(): with mock.patch.object( type(client.transport.get_metadata_schema), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_metadata_schema() call.assert_called() _, args, _ = call.mock_calls[0] @@ -10815,6 +13464,9 @@ def test_get_metadata_schema_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_metadata_schema), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_metadata_schema(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -10823,6 +13475,45 @@ def test_get_metadata_schema_non_empty_request_with_auto_populated_field(): ) +def test_get_metadata_schema_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_metadata_schema in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.get_metadata_schema + ] = mock_rpc + request = {} + client.get_metadata_schema(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_metadata_schema(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_metadata_schema_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -10852,6 +13543,52 @@ async def test_get_metadata_schema_empty_call_async(): assert args[0] == metadata_service.GetMetadataSchemaRequest() +@pytest.mark.asyncio +async def test_get_metadata_schema_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_metadata_schema + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_metadata_schema + ] = mock_object + + request = {} + await client.get_metadata_schema(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_metadata_schema(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_metadata_schema_async( transport: str = "grpc_asyncio", @@ -11106,6 +13843,9 @@ def test_list_metadata_schemas_empty_call(): with mock.patch.object( type(client.transport.list_metadata_schemas), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_metadata_schemas() call.assert_called() _, args, _ = call.mock_calls[0] @@ -11133,6 +13873,9 @@ def test_list_metadata_schemas_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_metadata_schemas), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_metadata_schemas(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -11143,6 +13886,46 @@ def test_list_metadata_schemas_non_empty_request_with_auto_populated_field(): ) +def test_list_metadata_schemas_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_metadata_schemas + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_metadata_schemas + ] = mock_rpc + request = {} + client.list_metadata_schemas(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_metadata_schemas(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_metadata_schemas_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -11168,6 +13951,52 @@ async def test_list_metadata_schemas_empty_call_async(): assert args[0] == metadata_service.ListMetadataSchemasRequest() +@pytest.mark.asyncio +async def test_list_metadata_schemas_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_metadata_schemas + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_metadata_schemas + ] = mock_object + + request = {} + await client.list_metadata_schemas(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_metadata_schemas(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_metadata_schemas_async( transport: str = "grpc_asyncio", @@ -11606,6 +14435,9 @@ def test_query_artifact_lineage_subgraph_empty_call(): with mock.patch.object( type(client.transport.query_artifact_lineage_subgraph), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.query_artifact_lineage_subgraph() call.assert_called() _, args, _ = call.mock_calls[0] @@ -11632,6 +14464,9 @@ def test_query_artifact_lineage_subgraph_non_empty_request_with_auto_populated_f with mock.patch.object( type(client.transport.query_artifact_lineage_subgraph), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.query_artifact_lineage_subgraph(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -11641,6 +14476,46 @@ def test_query_artifact_lineage_subgraph_non_empty_request_with_auto_populated_f ) +def test_query_artifact_lineage_subgraph_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.query_artifact_lineage_subgraph + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.query_artifact_lineage_subgraph + ] = mock_rpc + request = {} + client.query_artifact_lineage_subgraph(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.query_artifact_lineage_subgraph(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_query_artifact_lineage_subgraph_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -11664,6 +14539,52 @@ async def test_query_artifact_lineage_subgraph_empty_call_async(): assert args[0] == metadata_service.QueryArtifactLineageSubgraphRequest() +@pytest.mark.asyncio +async def test_query_artifact_lineage_subgraph_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.query_artifact_lineage_subgraph + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.query_artifact_lineage_subgraph + ] = mock_object + + request = {} + await client.query_artifact_lineage_subgraph(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.query_artifact_lineage_subgraph(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_query_artifact_lineage_subgraph_async( transport: str = "grpc_asyncio", @@ -11966,6 +14887,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_metadata_store_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_metadata_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_metadata_store + ] = mock_rpc + + request = {} + client.create_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_metadata_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_metadata_store_rest_required_fields( request_type=metadata_service.CreateMetadataStoreRequest, ): @@ -12247,6 +15213,46 @@ def test_get_metadata_store_rest(request_type): assert response.description == "description_value" +def test_get_metadata_store_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_metadata_store in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_metadata_store + ] = mock_rpc + + request = {} + client.get_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_metadata_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_metadata_store_rest_required_fields( request_type=metadata_service.GetMetadataStoreRequest, ): @@ -12516,6 +15522,46 @@ def test_list_metadata_stores_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_metadata_stores_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_metadata_stores in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_metadata_stores + ] = mock_rpc + + request = {} + client.list_metadata_stores(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_metadata_stores(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_metadata_stores_rest_required_fields( request_type=metadata_service.ListMetadataStoresRequest, ): @@ -12832,22 +15878,67 @@ def test_delete_metadata_store_rest(request_type): request_init = {"name": "projects/sample1/locations/sample2/metadataStores/sample3"} request = request_type(**request_init) - # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), "request") as req: - # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.delete_metadata_store(request) + + # Establish that the response is the type that we expect. 
+ assert response.operation.name == "operations/spam" + + +def test_delete_metadata_store_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_metadata_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_metadata_store + ] = mock_rpc + + request = {} + client.delete_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 200 - json_return_value = json_format.MessageToJson(return_value) + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() - response_value._content = json_return_value.encode("UTF-8") - req.return_value = response_value - response = client.delete_metadata_store(request) + client.delete_metadata_store(request) - # Establish that the response is the type that we expect. 
- assert response.operation.name == "operations/spam" + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 def test_delete_metadata_store_rest_required_fields( @@ -13215,6 +16306,42 @@ def get_message_fields(field): assert response.description == "description_value" +def test_create_artifact_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_artifact in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_artifact] = mock_rpc + + request = {} + client.create_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_artifact_rest_required_fields( request_type=metadata_service.CreateArtifactRequest, ): @@ -13517,6 +16644,42 @@ def test_get_artifact_rest(request_type): assert response.description == "description_value" +def test_get_artifact_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_artifact in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_artifact] = mock_rpc + + request = {} + client.get_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_artifact_rest_required_fields( request_type=metadata_service.GetArtifactRequest, ): @@ -13788,6 +16951,42 @@ def test_list_artifacts_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_artifacts_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_artifacts in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_artifacts] = mock_rpc + + request = {} + client.list_artifacts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_artifacts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_artifacts_rest_required_fields( request_type=metadata_service.ListArtifactsRequest, ): @@ -14236,6 +17435,42 @@ def get_message_fields(field): assert response.description == "description_value" +def test_update_artifact_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_artifact in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_artifact] = mock_rpc + + request = {} + client.update_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_artifact_rest_required_fields( request_type=metadata_service.UpdateArtifactRequest, ): @@ -14521,6 +17756,46 @@ def test_delete_artifact_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_artifact_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_artifact in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_artifact] = mock_rpc + + request = {} + client.delete_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_artifact_rest_required_fields( request_type=metadata_service.DeleteArtifactRequest, ): @@ -14788,6 +18063,46 @@ def test_purge_artifacts_rest(request_type): assert response.operation.name == "operations/spam" +def test_purge_artifacts_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.purge_artifacts in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.purge_artifacts] = mock_rpc + + request = {} + client.purge_artifacts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.purge_artifacts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_purge_artifacts_rest_required_fields( request_type=metadata_service.PurgeArtifactsRequest, ): @@ -15163,6 +18478,42 @@ def get_message_fields(field): assert response.description == "description_value" +def test_create_context_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_context in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_context] = mock_rpc + + request = {} + client.create_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_context_rest_required_fields( request_type=metadata_service.CreateContextRequest, ): @@ -15461,6 +18812,42 @@ def test_get_context_rest(request_type): assert response.description == "description_value" +def test_get_context_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_context in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_context] = mock_rpc + + request = {} + client.get_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_context_rest_required_fields( request_type=metadata_service.GetContextRequest, ): @@ -15732,6 +19119,42 @@ def test_list_contexts_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_contexts_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_contexts in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_contexts] = mock_rpc + + request = {} + client.list_contexts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_contexts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_contexts_rest_required_fields( request_type=metadata_service.ListContextsRequest, ): @@ -16177,6 +19600,42 @@ def get_message_fields(field): assert response.description == "description_value" +def test_update_context_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_context in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_context] = mock_rpc + + request = {} + client.update_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_context_rest_required_fields( request_type=metadata_service.UpdateContextRequest, ): @@ -16460,6 +19919,46 @@ def test_delete_context_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_context_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_context in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_context] = mock_rpc + + request = {} + client.delete_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_context_rest_required_fields( request_type=metadata_service.DeleteContextRequest, ): @@ -16740,6 +20239,46 @@ def test_purge_contexts_rest(request_type): assert response.operation.name == "operations/spam" +def test_purge_contexts_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.purge_contexts in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.purge_contexts] = mock_rpc + + request = {} + client.purge_contexts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.purge_contexts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_purge_contexts_rest_required_fields( request_type=metadata_service.PurgeContextsRequest, ): @@ -17024,6 +20563,47 @@ def test_add_context_artifacts_and_executions_rest(request_type): ) +def test_add_context_artifacts_and_executions_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.add_context_artifacts_and_executions + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.add_context_artifacts_and_executions + ] = mock_rpc + + request = {} + client.add_context_artifacts_and_executions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.add_context_artifacts_and_executions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_add_context_artifacts_and_executions_rest_required_fields( request_type=metadata_service.AddContextArtifactsAndExecutionsRequest, ): @@ -17297,19 +20877,59 @@ def test_add_context_children_rest(request_type): # Designate an appropriate value for the returned response. return_value = metadata_service.AddContextChildrenResponse() - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 200 - # Convert return value to protobuf type - return_value = metadata_service.AddContextChildrenResponse.pb(return_value) - json_return_value = json_format.MessageToJson(return_value) + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = metadata_service.AddContextChildrenResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.add_context_children(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, metadata_service.AddContextChildrenResponse) + + +def test_add_context_children_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.add_context_children in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.add_context_children + ] = mock_rpc + + request = {} + client.add_context_children(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 - response_value._content = json_return_value.encode("UTF-8") - req.return_value = response_value - response = client.add_context_children(request) + client.add_context_children(request) - # Establish that the response is the type that we expect. 
- assert isinstance(response, metadata_service.AddContextChildrenResponse) + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 def test_add_context_children_rest_required_fields( @@ -17585,6 +21205,47 @@ def test_remove_context_children_rest(request_type): assert isinstance(response, metadata_service.RemoveContextChildrenResponse) +def test_remove_context_children_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.remove_context_children + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.remove_context_children + ] = mock_rpc + + request = {} + client.remove_context_children(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.remove_context_children(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_remove_context_children_rest_required_fields( request_type=metadata_service.RemoveContextChildrenRequest, ): @@ -17862,6 +21523,47 @@ def test_query_context_lineage_subgraph_rest(request_type): assert isinstance(response, lineage_subgraph.LineageSubgraph) +def test_query_context_lineage_subgraph_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.query_context_lineage_subgraph + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.query_context_lineage_subgraph + ] = mock_rpc + + request = {} + client.query_context_lineage_subgraph(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.query_context_lineage_subgraph(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_query_context_lineage_subgraph_rest_required_fields( request_type=metadata_service.QueryContextLineageSubgraphRequest, ): @@ -18230,6 +21932,44 @@ def get_message_fields(field): assert response.description == "description_value" +def test_create_execution_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_execution in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_execution + ] = mock_rpc + + request = {} + client.create_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_execution_rest_required_fields( request_type=metadata_service.CreateExecutionRequest, ): @@ -18530,6 +22270,42 @@ def test_get_execution_rest(request_type): assert response.description == "description_value" +def test_get_execution_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_execution in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_execution] = mock_rpc + + request = {} + client.get_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_execution_rest_required_fields( request_type=metadata_service.GetExecutionRequest, ): @@ -18801,6 +22577,42 @@ def test_list_executions_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_executions_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_executions in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_executions] = mock_rpc + + request = {} + client.list_executions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_executions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_executions_rest_required_fields( request_type=metadata_service.ListExecutionsRequest, ): @@ -19246,6 +23058,44 @@ def get_message_fields(field): assert response.description == "description_value" +def test_update_execution_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_execution in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_execution + ] = mock_rpc + + request = {} + client.update_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_execution_rest_required_fields( request_type=metadata_service.UpdateExecutionRequest, ): @@ -19531,6 +23381,48 @@ def test_delete_execution_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_execution_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_execution in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_execution + ] = mock_rpc + + request = {} + client.delete_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_execution_rest_required_fields( request_type=metadata_service.DeleteExecutionRequest, ): @@ -19798,6 +23690,48 @@ def test_purge_executions_rest(request_type): assert response.operation.name == "operations/spam" +def test_purge_executions_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.purge_executions in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.purge_executions + ] = mock_rpc + + request = {} + client.purge_executions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.purge_executions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_purge_executions_rest_required_fields( request_type=metadata_service.PurgeExecutionsRequest, ): @@ -20078,6 +24012,46 @@ def test_add_execution_events_rest(request_type): assert isinstance(response, metadata_service.AddExecutionEventsResponse) +def test_add_execution_events_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.add_execution_events in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.add_execution_events + ] = mock_rpc + + request = {} + client.add_execution_events(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.add_execution_events(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_add_execution_events_rest_required_fields( request_type=metadata_service.AddExecutionEventsRequest, ): @@ -20351,6 +24325,47 @@ def test_query_execution_inputs_and_outputs_rest(request_type): assert isinstance(response, lineage_subgraph.LineageSubgraph) +def test_query_execution_inputs_and_outputs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.query_execution_inputs_and_outputs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.query_execution_inputs_and_outputs + ] = mock_rpc + + request = {} + client.query_execution_inputs_and_outputs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.query_execution_inputs_and_outputs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_query_execution_inputs_and_outputs_rest_required_fields( request_type=metadata_service.QueryExecutionInputsAndOutputsRequest, ): @@ -20719,6 +24734,47 @@ def get_message_fields(field): assert response.description == "description_value" +def test_create_metadata_schema_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_metadata_schema + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_metadata_schema + ] = mock_rpc + + request = {} + client.create_metadata_schema(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_metadata_schema(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_metadata_schema_rest_required_fields( request_type=metadata_service.CreateMetadataSchemaRequest, ): @@ -21018,6 +25074,46 @@ def test_get_metadata_schema_rest(request_type): assert response.description == "description_value" +def test_get_metadata_schema_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_metadata_schema in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_metadata_schema + ] = mock_rpc + + request = {} + client.get_metadata_schema(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_metadata_schema(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_metadata_schema_rest_required_fields( request_type=metadata_service.GetMetadataSchemaRequest, ): @@ -21291,6 +25387,47 @@ def test_list_metadata_schemas_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_metadata_schemas_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_metadata_schemas + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_metadata_schemas + ] = mock_rpc + + request = {} + client.list_metadata_schemas(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_metadata_schemas(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_metadata_schemas_rest_required_fields( request_type=metadata_service.ListMetadataSchemasRequest, ): @@ -21639,6 +25776,47 @@ def test_query_artifact_lineage_subgraph_rest(request_type): assert isinstance(response, lineage_subgraph.LineageSubgraph) +def test_query_artifact_lineage_subgraph_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.query_artifact_lineage_subgraph + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.query_artifact_lineage_subgraph + ] = mock_rpc + + request = {} + client.query_artifact_lineage_subgraph(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.query_artifact_lineage_subgraph(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_query_artifact_lineage_subgraph_rest_required_fields( request_type=metadata_service.QueryArtifactLineageSubgraphRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1/test_migration_service.py index c3ccb36026..1af26857c4 100644 --- a/tests/unit/gapic/aiplatform_v1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_migration_service.py @@ -1211,6 +1211,9 @@ def test_search_migratable_resources_empty_call(): with mock.patch.object( type(client.transport.search_migratable_resources), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.search_migratable_resources() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1238,6 +1241,9 @@ def test_search_migratable_resources_non_empty_request_with_auto_populated_field with mock.patch.object( type(client.transport.search_migratable_resources), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.search_migratable_resources(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1248,6 +1254,46 @@ def test_search_migratable_resources_non_empty_request_with_auto_populated_field ) +def test_search_migratable_resources_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MigrationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.search_migratable_resources + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.search_migratable_resources + ] = mock_rpc + request = {} + client.search_migratable_resources(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_migratable_resources(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_search_migratable_resources_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1273,6 +1319,52 @@ async def test_search_migratable_resources_empty_call_async(): assert args[0] == migration_service.SearchMigratableResourcesRequest() +@pytest.mark.asyncio +async def test_search_migratable_resources_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MigrationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.search_migratable_resources + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.search_migratable_resources + ] = mock_object + + request = {} + await client.search_migratable_resources(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.search_migratable_resources(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_search_migratable_resources_async( transport: str = "grpc_asyncio", @@ -1715,6 +1807,9 @@ def test_batch_migrate_resources_empty_call(): with mock.patch.object( type(client.transport.batch_migrate_resources), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_migrate_resources() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1740,6 +1835,9 @@ def test_batch_migrate_resources_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.batch_migrate_resources), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_migrate_resources(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1748,6 +1846,50 @@ def test_batch_migrate_resources_non_empty_request_with_auto_populated_field(): ) +def test_batch_migrate_resources_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MigrationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_migrate_resources + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a 
string. + ) + client._transport._wrapped_methods[ + client._transport.batch_migrate_resources + ] = mock_rpc + request = {} + client.batch_migrate_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.batch_migrate_resources(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_batch_migrate_resources_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1771,6 +1913,56 @@ async def test_batch_migrate_resources_empty_call_async(): assert args[0] == migration_service.BatchMigrateResourcesRequest() +@pytest.mark.asyncio +async def test_batch_migrate_resources_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MigrationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_migrate_resources + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_migrate_resources + ] = mock_object + + request = {} + await client.batch_migrate_resources(request) + + # Establish that 
the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.batch_migrate_resources(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_batch_migrate_resources_async( transport: str = "grpc_asyncio", @@ -2049,6 +2241,47 @@ def test_search_migratable_resources_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_search_migratable_resources_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MigrationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.search_migratable_resources + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.search_migratable_resources + ] = mock_rpc + + request = {} + client.search_migratable_resources(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_migratable_resources(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_search_migratable_resources_rest_required_fields( request_type=migration_service.SearchMigratableResourcesRequest, ): @@ -2379,6 +2612,51 @@ def test_batch_migrate_resources_rest(request_type): assert response.operation.name == "operations/spam" +def test_batch_migrate_resources_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MigrationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_migrate_resources + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_migrate_resources + ] = mock_rpc + + request = {} + client.batch_migrate_resources(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.batch_migrate_resources(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_batch_migrate_resources_rest_required_fields( request_type=migration_service.BatchMigrateResourcesRequest, ): @@ -3255,19 +3533,22 @@ def test_parse_annotated_dataset_path(): def test_dataset_path(): project = "cuttlefish" - dataset = "mussel" - expected = "projects/{project}/datasets/{dataset}".format( + location = "mussel" + dataset = "winkle" + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( project=project, + location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, dataset) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "winkle", - "dataset": "nautilus", + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", } path = MigrationServiceClient.dataset_path(**expected) @@ -3277,22 +3558,19 @@ def test_parse_dataset_path(): def test_dataset_path(): - project = "scallop" - location = "abalone" - dataset = "squid" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project = "squid" + dataset = "clam" + expected = "projects/{project}/datasets/{dataset}".format( project=project, - location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "clam", - "location": "whelk", + "project": "whelk", "dataset": "octopus", } path = MigrationServiceClient.dataset_path(**expected) diff --git 
a/tests/unit/gapic/aiplatform_v1/test_model_garden_service.py b/tests/unit/gapic/aiplatform_v1/test_model_garden_service.py index 81f7646e3c..7391745d33 100644 --- a/tests/unit/gapic/aiplatform_v1/test_model_garden_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_model_garden_service.py @@ -1241,6 +1241,9 @@ def test_get_publisher_model_empty_call(): with mock.patch.object( type(client.transport.get_publisher_model), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_publisher_model() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1267,6 +1270,9 @@ def test_get_publisher_model_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_publisher_model), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_publisher_model(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1276,6 +1282,45 @@ def test_get_publisher_model_non_empty_request_with_auto_populated_field(): ) +def test_get_publisher_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_publisher_model in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.get_publisher_model + ] = mock_rpc + request = {} + client.get_publisher_model(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_publisher_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_publisher_model_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1307,6 +1352,52 @@ async def test_get_publisher_model_empty_call_async(): assert args[0] == model_garden_service.GetPublisherModelRequest() +@pytest.mark.asyncio +async def test_get_publisher_model_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_publisher_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_publisher_model + ] = mock_object + + request = {} + await client.get_publisher_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_publisher_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_publisher_model_async( transport: str = "grpc_asyncio", @@ -1580,6 +1671,46 @@ def test_get_publisher_model_rest(request_type): assert response.publisher_model_template == "publisher_model_template_value" +def test_get_publisher_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_publisher_model in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_publisher_model + ] = mock_rpc + + request = {} + client.get_publisher_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_publisher_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_publisher_model_rest_required_fields( request_type=model_garden_service.GetPublisherModelRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_model_service.py b/tests/unit/gapic/aiplatform_v1/test_model_service.py index 167e27f741..8198ef26bf 100644 --- a/tests/unit/gapic/aiplatform_v1/test_model_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_model_service.py @@ -1167,6 +1167,9 @@ def test_upload_model_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.upload_model() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1193,6 +1196,9 @@ def test_upload_model_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.upload_model(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1204,6 +1210,45 @@ def test_upload_model_non_empty_request_with_auto_populated_field(): ) +def test_upload_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.upload_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.upload_model] = mock_rpc + request = {} + client.upload_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.upload_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_upload_model_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1225,6 +1270,56 @@ async def test_upload_model_empty_call_async(): assert args[0] == model_service.UploadModelRequest() +@pytest.mark.asyncio +async def test_upload_model_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.upload_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.upload_model + ] = mock_object + + request = {} + await client.upload_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.upload_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_upload_model_async( transport: str = "grpc_asyncio", request_type=model_service.UploadModelRequest @@ -1496,6 +1591,9 @@ def test_get_model_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_model() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1519,6 +1617,9 @@ def test_get_model_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_model(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1527,6 +1628,41 @@ def test_get_model_non_empty_request_with_auto_populated_field(): ) +def test_get_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_model] = mock_rpc + request = {} + client.get_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_model_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1570,6 +1706,50 @@ async def test_get_model_empty_call_async(): assert args[0] == model_service.GetModelRequest() +@pytest.mark.asyncio +async def test_get_model_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_model + ] = mock_object + + request = {} + await client.get_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_model_async( transport: str = "grpc_asyncio", request_type=model_service.GetModelRequest @@ -1834,6 +2014,9 @@ def test_list_models_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_models), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_models() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1860,6 +2043,9 @@ def test_list_models_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_models), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_models(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1871,6 +2057,41 @@ def test_list_models_non_empty_request_with_auto_populated_field(): ) +def test_list_models_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_models in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_models] = mock_rpc + request = {} + client.list_models(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_models_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1894,6 +2115,52 @@ async def test_list_models_empty_call_async(): assert args[0] == model_service.ListModelsRequest() +@pytest.mark.asyncio +async def test_list_models_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_models + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_models + ] = mock_object + + request = {} + await client.list_models(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_models_async( transport: str = "grpc_asyncio", request_type=model_service.ListModelsRequest @@ -2316,6 +2583,9 @@ def test_list_model_versions_empty_call(): with mock.patch.object( type(client.transport.list_model_versions), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_model_versions() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2344,6 +2614,9 @@ def test_list_model_versions_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_model_versions), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_model_versions(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2355,6 +2628,45 @@ def test_list_model_versions_non_empty_request_with_auto_populated_field(): ) +def test_list_model_versions_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_model_versions in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_model_versions + ] = mock_rpc + request = {} + client.list_model_versions(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_model_versions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_model_versions_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2380,6 +2692,52 @@ async def test_list_model_versions_empty_call_async(): assert args[0] == model_service.ListModelVersionsRequest() +@pytest.mark.asyncio +async def test_list_model_versions_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_model_versions + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_model_versions + ] = mock_object + + request = {} + await client.list_model_versions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_model_versions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_model_versions_async( transport: str = "grpc_asyncio", request_type=model_service.ListModelVersionsRequest @@ -2852,6 +3210,9 @@ def test_update_model_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_model() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2873,12 +3234,50 @@ def test_update_model_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_model(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.UpdateModelRequest() +def test_update_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_model] = mock_rpc + request = {} + client.update_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_model_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2922,6 +3321,52 @@ async def test_update_model_empty_call_async(): assert args[0] == model_service.UpdateModelRequest() +@pytest.mark.asyncio +async def test_update_model_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_model + ] = mock_object + + request = {} + await client.update_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_model_async( transport: str = "grpc_asyncio", request_type=model_service.UpdateModelRequest @@ -3197,6 +3642,9 @@ def test_update_explanation_dataset_empty_call(): with mock.patch.object( type(client.transport.update_explanation_dataset), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_explanation_dataset() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3222,6 +3670,9 @@ def test_update_explanation_dataset_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.update_explanation_dataset), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_explanation_dataset(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3230,6 +3681,50 @@ def test_update_explanation_dataset_non_empty_request_with_auto_populated_field( ) +def test_update_explanation_dataset_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_explanation_dataset + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_explanation_dataset + ] = mock_rpc + request = {} + client.update_explanation_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_explanation_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_explanation_dataset_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3253,6 +3748,56 @@ async def test_update_explanation_dataset_empty_call_async(): assert args[0] == model_service.UpdateExplanationDatasetRequest() +@pytest.mark.asyncio +async def test_update_explanation_dataset_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_explanation_dataset + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_explanation_dataset + ] = mock_object + + request = {} + await client.update_explanation_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_explanation_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_explanation_dataset_async( transport: str = "grpc_asyncio", @@ -3486,6 +4031,9 @@ def test_delete_model_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_model() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3509,6 +4057,9 @@ def test_delete_model_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_model(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3517,6 +4068,45 @@ def test_delete_model_non_empty_request_with_auto_populated_field(): ) +def test_delete_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_model] = mock_rpc + request = {} + client.delete_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_model_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3539,10 +4129,60 @@ async def test_delete_model_empty_call_async(): @pytest.mark.asyncio -async def test_delete_model_async( - transport: str = "grpc_asyncio", request_type=model_service.DeleteModelRequest +async def test_delete_model_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", ): - client = ModelServiceAsyncClient( + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_model + ] = mock_object + + request = {} + await client.delete_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_model_async( + transport: str = "grpc_asyncio", request_type=model_service.DeleteModelRequest +): + client = ModelServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3764,6 +4404,9 @@ def test_delete_model_version_empty_call(): with mock.patch.object( type(client.transport.delete_model_version), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_model_version() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3789,6 +4432,9 @@ def test_delete_model_version_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_model_version), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_model_version(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3797,6 +4443,49 @@ def test_delete_model_version_non_empty_request_with_auto_populated_field(): ) +def test_delete_model_version_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_model_version in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_model_version + ] = mock_rpc + request = {} + client.delete_model_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_model_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_model_version_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3820,6 +4509,56 @@ async def test_delete_model_version_empty_call_async(): assert args[0] == model_service.DeleteModelVersionRequest() +@pytest.mark.asyncio +async def test_delete_model_version_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_model_version + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_model_version + ] = mock_object + + request = {} + await client.delete_model_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_model_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_model_version_async( transport: str = "grpc_asyncio", @@ -4096,6 +4835,9 @@ def test_merge_version_aliases_empty_call(): with mock.patch.object( type(client.transport.merge_version_aliases), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.merge_version_aliases() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4121,6 +4863,9 @@ def test_merge_version_aliases_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.merge_version_aliases), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.merge_version_aliases(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4129,6 +4874,46 @@ def test_merge_version_aliases_non_empty_request_with_auto_populated_field(): ) +def test_merge_version_aliases_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.merge_version_aliases + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.merge_version_aliases + ] = mock_rpc + request = {} + client.merge_version_aliases(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.merge_version_aliases(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_merge_version_aliases_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4174,6 +4959,52 @@ async def test_merge_version_aliases_empty_call_async(): assert args[0] == model_service.MergeVersionAliasesRequest() +@pytest.mark.asyncio +async def test_merge_version_aliases_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.merge_version_aliases + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.merge_version_aliases + ] = mock_object + + request = {} + await client.merge_version_aliases(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.merge_version_aliases(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_merge_version_aliases_async( transport: str = "grpc_asyncio", @@ -4456,6 +5287,9 @@ def test_export_model_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.export_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.export_model() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4479,6 +5313,9 @@ def test_export_model_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.export_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.export_model(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4487,6 +5324,45 @@ def test_export_model_non_empty_request_with_auto_populated_field(): ) +def test_export_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.export_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.export_model] = mock_rpc + request = {} + client.export_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.export_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_export_model_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4508,6 +5384,56 @@ async def test_export_model_empty_call_async(): assert args[0] == model_service.ExportModelRequest() +@pytest.mark.asyncio +async def test_export_model_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.export_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.export_model + ] = mock_object + + request = {} + await client.export_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.export_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_export_model_async( transport: str = "grpc_asyncio", request_type=model_service.ExportModelRequest @@ -4752,6 +5678,9 @@ def test_copy_model_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.copy_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.copy_model() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4778,6 +5707,9 @@ def test_copy_model_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.copy_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.copy_model(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4789,6 +5721,45 @@ def test_copy_model_non_empty_request_with_auto_populated_field(): ) +def test_copy_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.copy_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.copy_model] = mock_rpc + request = {} + client.copy_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.copy_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_copy_model_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4810,6 +5781,54 @@ async def test_copy_model_empty_call_async(): assert args[0] == model_service.CopyModelRequest() +@pytest.mark.asyncio +async def test_copy_model_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.copy_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.copy_model + ] = mock_object + + request = {} + await client.copy_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.copy_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_copy_model_async( transport: str = "grpc_asyncio", request_type=model_service.CopyModelRequest @@ -5059,6 +6078,9 @@ def test_import_model_evaluation_empty_call(): with mock.patch.object( type(client.transport.import_model_evaluation), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.import_model_evaluation() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5084,6 +6106,9 @@ def test_import_model_evaluation_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.import_model_evaluation), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.import_model_evaluation(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5092,6 +6117,46 @@ def test_import_model_evaluation_non_empty_request_with_auto_populated_field(): ) +def test_import_model_evaluation_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.import_model_evaluation + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.import_model_evaluation + ] = mock_rpc + request = {} + client.import_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.import_model_evaluation(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_import_model_evaluation_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5122,6 +6187,52 @@ async def test_import_model_evaluation_empty_call_async(): assert args[0] == model_service.ImportModelEvaluationRequest() +@pytest.mark.asyncio +async def test_import_model_evaluation_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.import_model_evaluation + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.import_model_evaluation + ] = mock_object + + request = {} + await client.import_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.import_model_evaluation(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_import_model_evaluation_async( transport: str = "grpc_asyncio", @@ -5387,6 +6498,9 @@ def test_batch_import_model_evaluation_slices_empty_call(): with mock.patch.object( type(client.transport.batch_import_model_evaluation_slices), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_import_model_evaluation_slices() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5412,6 +6526,9 @@ def test_batch_import_model_evaluation_slices_non_empty_request_with_auto_popula with mock.patch.object( type(client.transport.batch_import_model_evaluation_slices), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.batch_import_model_evaluation_slices(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5420,6 +6537,46 @@ def test_batch_import_model_evaluation_slices_non_empty_request_with_auto_popula ) +def test_batch_import_model_evaluation_slices_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_import_model_evaluation_slices + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_import_model_evaluation_slices + ] = mock_rpc + request = {} + client.batch_import_model_evaluation_slices(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.batch_import_model_evaluation_slices(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_batch_import_model_evaluation_slices_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5447,6 +6604,52 @@ async def test_batch_import_model_evaluation_slices_empty_call_async(): assert args[0] == model_service.BatchImportModelEvaluationSlicesRequest() +@pytest.mark.asyncio +async def test_batch_import_model_evaluation_slices_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_import_model_evaluation_slices + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_import_model_evaluation_slices + ] = mock_object + + request = {} + await client.batch_import_model_evaluation_slices(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.batch_import_model_evaluation_slices(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_batch_import_model_evaluation_slices_async( transport: str = "grpc_asyncio", @@ -5712,6 +6915,9 @@ def test_batch_import_evaluated_annotations_empty_call(): with mock.patch.object( type(client.transport.batch_import_evaluated_annotations), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_import_evaluated_annotations() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5737,6 +6943,9 @@ def test_batch_import_evaluated_annotations_non_empty_request_with_auto_populate with mock.patch.object( type(client.transport.batch_import_evaluated_annotations), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.batch_import_evaluated_annotations(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5745,6 +6954,46 @@ def test_batch_import_evaluated_annotations_non_empty_request_with_auto_populate ) +def test_batch_import_evaluated_annotations_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_import_evaluated_annotations + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_import_evaluated_annotations + ] = mock_rpc + request = {} + client.batch_import_evaluated_annotations(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.batch_import_evaluated_annotations(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_batch_import_evaluated_annotations_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5770,6 +7019,52 @@ async def test_batch_import_evaluated_annotations_empty_call_async(): assert args[0] == model_service.BatchImportEvaluatedAnnotationsRequest() +@pytest.mark.asyncio +async def test_batch_import_evaluated_annotations_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_import_evaluated_annotations + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_import_evaluated_annotations + ] = mock_object + + request = {} + await client.batch_import_evaluated_annotations(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.batch_import_evaluated_annotations(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_batch_import_evaluated_annotations_async( transport: str = "grpc_asyncio", @@ -6057,6 +7352,9 @@ def test_get_model_evaluation_empty_call(): with mock.patch.object( type(client.transport.get_model_evaluation), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_model_evaluation() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6082,6 +7380,9 @@ def test_get_model_evaluation_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_model_evaluation), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_model_evaluation(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6090,6 +7391,45 @@ def test_get_model_evaluation_non_empty_request_with_auto_populated_field(): ) +def test_get_model_evaluation_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_model_evaluation in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.get_model_evaluation + ] = mock_rpc + request = {} + client.get_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_model_evaluation(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_model_evaluation_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6120,6 +7460,52 @@ async def test_get_model_evaluation_empty_call_async(): assert args[0] == model_service.GetModelEvaluationRequest() +@pytest.mark.asyncio +async def test_get_model_evaluation_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_model_evaluation + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_model_evaluation + ] = mock_object + + request = {} + await client.get_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_model_evaluation(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_model_evaluation_async( transport: str = "grpc_asyncio", @@ -6373,6 +7759,9 @@ def test_list_model_evaluations_empty_call(): with mock.patch.object( type(client.transport.list_model_evaluations), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_model_evaluations() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6400,6 +7789,9 @@ def test_list_model_evaluations_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_model_evaluations), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_model_evaluations(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6410,6 +7802,46 @@ def test_list_model_evaluations_non_empty_request_with_auto_populated_field(): ) +def test_list_model_evaluations_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_model_evaluations + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_model_evaluations + ] = mock_rpc + request = {} + client.list_model_evaluations(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_model_evaluations(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_model_evaluations_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6435,6 +7867,52 @@ async def test_list_model_evaluations_empty_call_async(): assert args[0] == model_service.ListModelEvaluationsRequest() +@pytest.mark.asyncio +async def test_list_model_evaluations_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_model_evaluations + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_model_evaluations + ] = mock_object + + request = {} + await client.list_model_evaluations(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_model_evaluations(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_model_evaluations_async( transport: str = "grpc_asyncio", @@ -6878,6 +8356,9 @@ def test_get_model_evaluation_slice_empty_call(): with mock.patch.object( type(client.transport.get_model_evaluation_slice), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_model_evaluation_slice() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6903,6 +8384,9 @@ def test_get_model_evaluation_slice_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.get_model_evaluation_slice), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_model_evaluation_slice(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6911,6 +8395,46 @@ def test_get_model_evaluation_slice_non_empty_request_with_auto_populated_field( ) +def test_get_model_evaluation_slice_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_model_evaluation_slice + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) 
expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_model_evaluation_slice + ] = mock_rpc + request = {} + client.get_model_evaluation_slice(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_model_evaluation_slice(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_model_evaluation_slice_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6937,6 +8461,52 @@ async def test_get_model_evaluation_slice_empty_call_async(): assert args[0] == model_service.GetModelEvaluationSliceRequest() +@pytest.mark.asyncio +async def test_get_model_evaluation_slice_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_model_evaluation_slice + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_model_evaluation_slice + ] = mock_object + + request = {} + await client.get_model_evaluation_slice(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_model_evaluation_slice(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_model_evaluation_slice_async( transport: str = "grpc_asyncio", @@ -7182,6 +8752,9 @@ def test_list_model_evaluation_slices_empty_call(): with mock.patch.object( type(client.transport.list_model_evaluation_slices), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_model_evaluation_slices() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7209,6 +8782,9 @@ def test_list_model_evaluation_slices_non_empty_request_with_auto_populated_fiel with mock.patch.object( type(client.transport.list_model_evaluation_slices), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_model_evaluation_slices(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7219,6 +8795,46 @@ def test_list_model_evaluation_slices_non_empty_request_with_auto_populated_fiel ) +def test_list_model_evaluation_slices_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_model_evaluation_slices + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_model_evaluation_slices + ] = mock_rpc + request = {} + client.list_model_evaluation_slices(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_model_evaluation_slices(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_model_evaluation_slices_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7244,6 +8860,52 @@ async def test_list_model_evaluation_slices_empty_call_async(): assert args[0] == model_service.ListModelEvaluationSlicesRequest() +@pytest.mark.asyncio +async def test_list_model_evaluation_slices_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_model_evaluation_slices + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_model_evaluation_slices + ] = mock_object + + request = {} + await client.list_model_evaluation_slices(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_model_evaluation_slices(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_model_evaluation_slices_async( transport: str = "grpc_asyncio", @@ -7675,6 +9337,46 @@ def test_upload_model_rest(request_type): assert response.operation.name == "operations/spam" +def test_upload_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.upload_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.upload_model] = mock_rpc + + request = {} + client.upload_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.upload_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_upload_model_rest_required_fields( request_type=model_service.UploadModelRequest, ): @@ -7986,6 +9688,42 @@ def test_get_model_rest(request_type): assert response.metadata_artifact == "metadata_artifact_value" +def test_get_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_model] = mock_rpc + + request = {} + client.get_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_model_rest_required_fields(request_type=model_service.GetModelRequest): transport_class = transports.ModelServiceRestTransport @@ -8246,6 +9984,42 @@ def test_list_models_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_models_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_models in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_models] = mock_rpc + + request = {} + client.list_models(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_models_rest_required_fields(request_type=model_service.ListModelsRequest): transport_class = transports.ModelServiceRestTransport @@ -8577,13 +10351,53 @@ def test_list_model_versions_rest(request_type): return_value = model_service.ListModelVersionsResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) - response_value._content = json_return_value.encode("UTF-8") - req.return_value = response_value - response = client.list_model_versions(request) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.list_model_versions(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListModelVersionsPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_model_versions_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_model_versions in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_model_versions + ] = mock_rpc + + request = {} + client.list_model_versions(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_model_versions(request) - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListModelVersionsPager) - assert response.next_page_token == "next_page_token_value" + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 def test_list_model_versions_rest_required_fields( @@ -9155,6 +10969,42 @@ def get_message_fields(field): assert response.metadata_artifact == "metadata_artifact_value" +def test_update_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_model] = mock_rpc + + request = {} + client.update_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_model_rest_required_fields( request_type=model_service.UpdateModelRequest, ): @@ -9427,6 +11277,51 @@ def test_update_explanation_dataset_rest(request_type): assert response.operation.name == "operations/spam" +def test_update_explanation_dataset_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_explanation_dataset + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_explanation_dataset + ] = mock_rpc + + request = {} + client.update_explanation_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_explanation_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_explanation_dataset_rest_required_fields( request_type=model_service.UpdateExplanationDatasetRequest, ): @@ -9687,6 +11582,46 @@ def test_delete_model_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_model] = mock_rpc + + request = {} + client.delete_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_model_rest_required_fields( request_type=model_service.DeleteModelRequest, ): @@ -9945,6 +11880,50 @@ def test_delete_model_version_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_model_version_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_model_version in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_model_version + ] = mock_rpc + + request = {} + client.delete_model_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_model_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_model_version_rest_required_fields( request_type=model_service.DeleteModelVersionRequest, ): @@ -10245,6 +12224,47 @@ def test_merge_version_aliases_rest(request_type): assert response.metadata_artifact == "metadata_artifact_value" +def test_merge_version_aliases_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.merge_version_aliases + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.merge_version_aliases + ] = mock_rpc + + request = {} + client.merge_version_aliases(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.merge_version_aliases(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_merge_version_aliases_rest_required_fields( request_type=model_service.MergeVersionAliasesRequest, ): @@ -10520,6 +12540,46 @@ def test_export_model_rest(request_type): assert response.operation.name == "operations/spam" +def test_export_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.export_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.export_model] = mock_rpc + + request = {} + client.export_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.export_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_export_model_rest_required_fields( request_type=model_service.ExportModelRequest, ): @@ -10794,6 +12854,46 @@ def test_copy_model_rest(request_type): assert response.operation.name == "operations/spam" +def test_copy_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.copy_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.copy_model] = mock_rpc + + request = {} + client.copy_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.copy_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_copy_model_rest_required_fields(request_type=model_service.CopyModelRequest): transport_class = transports.ModelServiceRestTransport @@ -11079,6 +13179,47 @@ def test_import_model_evaluation_rest(request_type): assert response.annotation_schema_uri == "annotation_schema_uri_value" +def test_import_model_evaluation_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.import_model_evaluation + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.import_model_evaluation + ] = mock_rpc + + request = {} + client.import_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.import_model_evaluation(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_import_model_evaluation_rest_required_fields( request_type=model_service.ImportModelEvaluationRequest, ): @@ -11363,6 +13504,47 @@ def test_batch_import_model_evaluation_slices_rest(request_type): ] +def test_batch_import_model_evaluation_slices_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_import_model_evaluation_slices + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_import_model_evaluation_slices + ] = mock_rpc + + request = {} + client.batch_import_model_evaluation_slices(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.batch_import_model_evaluation_slices(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_batch_import_model_evaluation_slices_rest_required_fields( request_type=model_service.BatchImportModelEvaluationSlicesRequest, ): @@ -11666,6 +13848,47 @@ def test_batch_import_evaluated_annotations_rest(request_type): assert response.imported_evaluated_annotations_count == 3859 +def test_batch_import_evaluated_annotations_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_import_evaluated_annotations + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_import_evaluated_annotations + ] = mock_rpc + + request = {} + client.batch_import_evaluated_annotations(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.batch_import_evaluated_annotations(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_batch_import_evaluated_annotations_rest_required_fields( request_type=model_service.BatchImportEvaluatedAnnotationsRequest, ): @@ -11980,6 +14203,46 @@ def test_get_model_evaluation_rest(request_type): assert response.annotation_schema_uri == "annotation_schema_uri_value" +def test_get_model_evaluation_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_model_evaluation in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_model_evaluation + ] = mock_rpc + + request = {} + client.get_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_model_evaluation(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_model_evaluation_rest_required_fields( request_type=model_service.GetModelEvaluationRequest, ): @@ -12251,6 +14514,47 @@ def test_list_model_evaluations_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_model_evaluations_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_model_evaluations + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_model_evaluations + ] = mock_rpc + + request = {} + client.list_model_evaluations(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_model_evaluations(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_model_evaluations_rest_required_fields( request_type=model_service.ListModelEvaluationsRequest, ): @@ -12598,6 +14902,47 @@ def test_get_model_evaluation_slice_rest(request_type): assert response.metrics_schema_uri == "metrics_schema_uri_value" +def test_get_model_evaluation_slice_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_model_evaluation_slice + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_model_evaluation_slice + ] = mock_rpc + + request = {} + client.get_model_evaluation_slice(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_model_evaluation_slice(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_model_evaluation_slice_rest_required_fields( request_type=model_service.GetModelEvaluationSliceRequest, ): @@ -12871,6 +15216,47 @@ def test_list_model_evaluation_slices_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_model_evaluation_slices_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_model_evaluation_slices + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_model_evaluation_slices + ] = mock_rpc + + request = {} + client.list_model_evaluation_slices(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_model_evaluation_slices(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_model_evaluation_slices_rest_required_fields( request_type=model_service.ListModelEvaluationSlicesRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_notebook_service.py b/tests/unit/gapic/aiplatform_v1/test_notebook_service.py index c6b3d5384b..05d9126251 100644 --- a/tests/unit/gapic/aiplatform_v1/test_notebook_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_notebook_service.py @@ -1214,6 +1214,9 @@ def test_create_notebook_runtime_template_empty_call(): with mock.patch.object( type(client.transport.create_notebook_runtime_template), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_notebook_runtime_template() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1240,6 +1243,9 @@ def test_create_notebook_runtime_template_non_empty_request_with_auto_populated_ with mock.patch.object( type(client.transport.create_notebook_runtime_template), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_notebook_runtime_template(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1249,6 +1255,50 @@ def test_create_notebook_runtime_template_non_empty_request_with_auto_populated_ ) +def test_create_notebook_runtime_template_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_notebook_runtime_template + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_notebook_runtime_template + ] = mock_rpc + request = {} + client.create_notebook_runtime_template(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_notebook_runtime_template(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_notebook_runtime_template_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1272,6 +1322,56 @@ async def test_create_notebook_runtime_template_empty_call_async(): assert args[0] == notebook_service.CreateNotebookRuntimeTemplateRequest() +@pytest.mark.asyncio +async def test_create_notebook_runtime_template_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_notebook_runtime_template + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_notebook_runtime_template + ] = mock_object + + request = {} + await client.create_notebook_runtime_template(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_notebook_runtime_template(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_notebook_runtime_template_async( transport: str = "grpc_asyncio", @@ -1557,6 +1657,9 @@ def test_get_notebook_runtime_template_empty_call(): with mock.patch.object( type(client.transport.get_notebook_runtime_template), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_notebook_runtime_template() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1582,6 +1685,9 @@ def test_get_notebook_runtime_template_non_empty_request_with_auto_populated_fie with mock.patch.object( type(client.transport.get_notebook_runtime_template), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_notebook_runtime_template(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1590,6 +1696,46 @@ def test_get_notebook_runtime_template_non_empty_request_with_auto_populated_fie ) +def test_get_notebook_runtime_template_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_notebook_runtime_template + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_notebook_runtime_template + ] = mock_rpc + request = {} + client.get_notebook_runtime_template(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_notebook_runtime_template(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_notebook_runtime_template_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1622,6 +1768,52 @@ async def test_get_notebook_runtime_template_empty_call_async(): assert args[0] == notebook_service.GetNotebookRuntimeTemplateRequest() +@pytest.mark.asyncio +async def test_get_notebook_runtime_template_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_notebook_runtime_template + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_notebook_runtime_template + ] = mock_object + + request = {} + await client.get_notebook_runtime_template(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_notebook_runtime_template(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_notebook_runtime_template_async( transport: str = "grpc_asyncio", @@ -1882,6 +2074,9 @@ def test_list_notebook_runtime_templates_empty_call(): with mock.patch.object( type(client.transport.list_notebook_runtime_templates), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_notebook_runtime_templates() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1910,6 +2105,9 @@ def test_list_notebook_runtime_templates_non_empty_request_with_auto_populated_f with mock.patch.object( type(client.transport.list_notebook_runtime_templates), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_notebook_runtime_templates(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1921,6 +2119,46 @@ def test_list_notebook_runtime_templates_non_empty_request_with_auto_populated_f ) +def test_list_notebook_runtime_templates_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_notebook_runtime_templates + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_notebook_runtime_templates + ] = mock_rpc + request = {} + client.list_notebook_runtime_templates(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_notebook_runtime_templates(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_notebook_runtime_templates_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1946,6 +2184,52 @@ async def test_list_notebook_runtime_templates_empty_call_async(): assert args[0] == notebook_service.ListNotebookRuntimeTemplatesRequest() +@pytest.mark.asyncio +async def test_list_notebook_runtime_templates_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_notebook_runtime_templates + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_notebook_runtime_templates + ] = mock_object + + request = {} + await client.list_notebook_runtime_templates(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_notebook_runtime_templates(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_notebook_runtime_templates_async( transport: str = "grpc_asyncio", @@ -2388,6 +2672,9 @@ def test_delete_notebook_runtime_template_empty_call(): with mock.patch.object( type(client.transport.delete_notebook_runtime_template), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_notebook_runtime_template() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2413,6 +2700,9 @@ def test_delete_notebook_runtime_template_non_empty_request_with_auto_populated_ with mock.patch.object( type(client.transport.delete_notebook_runtime_template), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_notebook_runtime_template(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2421,6 +2711,50 @@ def test_delete_notebook_runtime_template_non_empty_request_with_auto_populated_ ) +def test_delete_notebook_runtime_template_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_notebook_runtime_template + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_notebook_runtime_template + ] = mock_rpc + request = {} + client.delete_notebook_runtime_template(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_notebook_runtime_template(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_notebook_runtime_template_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2444,6 +2778,56 @@ async def test_delete_notebook_runtime_template_empty_call_async(): assert args[0] == notebook_service.DeleteNotebookRuntimeTemplateRequest() +@pytest.mark.asyncio +async def test_delete_notebook_runtime_template_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_notebook_runtime_template + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_notebook_runtime_template + ] = mock_object + + request = {} + await client.delete_notebook_runtime_template(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_notebook_runtime_template(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_notebook_runtime_template_async( transport: str = "grpc_asyncio", @@ -2681,6 +3065,9 @@ def test_assign_notebook_runtime_empty_call(): with mock.patch.object( type(client.transport.assign_notebook_runtime), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.assign_notebook_runtime() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2708,6 +3095,9 @@ def test_assign_notebook_runtime_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.assign_notebook_runtime), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.assign_notebook_runtime(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2718,6 +3108,50 @@ def test_assign_notebook_runtime_non_empty_request_with_auto_populated_field(): ) +def test_assign_notebook_runtime_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.assign_notebook_runtime + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.assign_notebook_runtime + ] = mock_rpc + request = {} + client.assign_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.assign_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_assign_notebook_runtime_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2741,6 +3175,56 @@ async def test_assign_notebook_runtime_empty_call_async(): assert args[0] == notebook_service.AssignNotebookRuntimeRequest() +@pytest.mark.asyncio +async def test_assign_notebook_runtime_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.assign_notebook_runtime + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.assign_notebook_runtime + ] = mock_object + + request = {} + await client.assign_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.assign_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_assign_notebook_runtime_async( transport: str = "grpc_asyncio", @@ -3038,6 +3522,9 @@ def test_get_notebook_runtime_empty_call(): with mock.patch.object( type(client.transport.get_notebook_runtime), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_notebook_runtime() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3063,6 +3550,9 @@ def test_get_notebook_runtime_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_notebook_runtime), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_notebook_runtime(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3071,6 +3561,45 @@ def test_get_notebook_runtime_non_empty_request_with_auto_populated_field(): ) +def test_get_notebook_runtime_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_notebook_runtime in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_notebook_runtime + ] = mock_rpc + request = {} + client.get_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_notebook_runtime_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3107,6 +3636,52 @@ async def test_get_notebook_runtime_empty_call_async(): assert args[0] == notebook_service.GetNotebookRuntimeRequest() +@pytest.mark.asyncio +async def test_get_notebook_runtime_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_notebook_runtime + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_notebook_runtime + ] = mock_object + + request = {} + await client.get_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_notebook_runtime_async( transport: str = "grpc_asyncio", @@ -3377,6 +3952,9 @@ def test_list_notebook_runtimes_empty_call(): with mock.patch.object( type(client.transport.list_notebook_runtimes), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_notebook_runtimes() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3405,6 +3983,9 @@ def test_list_notebook_runtimes_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_notebook_runtimes), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_notebook_runtimes(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3416,6 +3997,46 @@ def test_list_notebook_runtimes_non_empty_request_with_auto_populated_field(): ) +def test_list_notebook_runtimes_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_notebook_runtimes + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_notebook_runtimes + ] = mock_rpc + request = {} + client.list_notebook_runtimes(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_notebook_runtimes(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_notebook_runtimes_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3441,6 +4062,52 @@ async def test_list_notebook_runtimes_empty_call_async(): assert args[0] == notebook_service.ListNotebookRuntimesRequest() +@pytest.mark.asyncio +async def test_list_notebook_runtimes_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_notebook_runtimes + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_notebook_runtimes + ] = mock_object + + request = {} + await client.list_notebook_runtimes(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_notebook_runtimes(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_notebook_runtimes_async( transport: str = "grpc_asyncio", @@ -3879,6 +4546,9 @@ def test_delete_notebook_runtime_empty_call(): with mock.patch.object( type(client.transport.delete_notebook_runtime), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_notebook_runtime() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3904,6 +4574,9 @@ def test_delete_notebook_runtime_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_notebook_runtime), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_notebook_runtime(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3912,6 +4585,50 @@ def test_delete_notebook_runtime_non_empty_request_with_auto_populated_field(): ) +def test_delete_notebook_runtime_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_notebook_runtime + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_notebook_runtime + ] = mock_rpc + request = {} + client.delete_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_notebook_runtime_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3935,6 +4652,56 @@ async def test_delete_notebook_runtime_empty_call_async(): assert args[0] == notebook_service.DeleteNotebookRuntimeRequest() +@pytest.mark.asyncio +async def test_delete_notebook_runtime_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_notebook_runtime + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_notebook_runtime + ] = mock_object + + request = {} + await client.delete_notebook_runtime(request) + + # Establish that the 
underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_notebook_runtime_async( transport: str = "grpc_asyncio", @@ -4172,6 +4939,9 @@ def test_upgrade_notebook_runtime_empty_call(): with mock.patch.object( type(client.transport.upgrade_notebook_runtime), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.upgrade_notebook_runtime() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4197,6 +4967,9 @@ def test_upgrade_notebook_runtime_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.upgrade_notebook_runtime), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.upgrade_notebook_runtime(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4205,6 +4978,50 @@ def test_upgrade_notebook_runtime_non_empty_request_with_auto_populated_field(): ) +def test_upgrade_notebook_runtime_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.upgrade_notebook_runtime + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.upgrade_notebook_runtime + ] = mock_rpc + request = {} + client.upgrade_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.upgrade_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_upgrade_notebook_runtime_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4228,6 +5045,56 @@ async def test_upgrade_notebook_runtime_empty_call_async(): assert args[0] == notebook_service.UpgradeNotebookRuntimeRequest() +@pytest.mark.asyncio +async def test_upgrade_notebook_runtime_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.upgrade_notebook_runtime + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.upgrade_notebook_runtime + ] = mock_object + + request = {} + await client.upgrade_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.upgrade_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_upgrade_notebook_runtime_async( transport: str = "grpc_asyncio", @@ -4465,6 +5332,9 @@ def test_start_notebook_runtime_empty_call(): with mock.patch.object( type(client.transport.start_notebook_runtime), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.start_notebook_runtime() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4490,6 +5360,9 @@ def test_start_notebook_runtime_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.start_notebook_runtime), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.start_notebook_runtime(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4498,6 +5371,50 @@ def test_start_notebook_runtime_non_empty_request_with_auto_populated_field(): ) +def test_start_notebook_runtime_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.start_notebook_runtime + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.start_notebook_runtime + ] = mock_rpc + request = {} + client.start_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.start_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_start_notebook_runtime_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4521,6 +5438,56 @@ async def test_start_notebook_runtime_empty_call_async(): assert args[0] == notebook_service.StartNotebookRuntimeRequest() +@pytest.mark.asyncio +async def test_start_notebook_runtime_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.start_notebook_runtime + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.start_notebook_runtime + ] = mock_object + + request = {} + await client.start_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.start_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_start_notebook_runtime_async( transport: str = "grpc_asyncio", @@ -4853,6 +5820,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_notebook_runtime_template_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_notebook_runtime_template + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_notebook_runtime_template + ] = mock_rpc + + request = {} + client.create_notebook_runtime_template(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_notebook_runtime_template(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_notebook_runtime_template_rest_required_fields( request_type=notebook_service.CreateNotebookRuntimeTemplateRequest, ): @@ -5160,6 +6172,47 @@ def test_get_notebook_runtime_template_rest(request_type): assert response.network_tags == ["network_tags_value"] +def test_get_notebook_runtime_template_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_notebook_runtime_template + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_notebook_runtime_template + ] = mock_rpc + + request = {} + client.get_notebook_runtime_template(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_notebook_runtime_template(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_notebook_runtime_template_rest_required_fields( request_type=notebook_service.GetNotebookRuntimeTemplateRequest, ): @@ -5436,6 +6489,47 @@ def test_list_notebook_runtime_templates_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_notebook_runtime_templates_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_notebook_runtime_templates + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_notebook_runtime_templates + ] = mock_rpc + + request = {} + client.list_notebook_runtime_templates(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_notebook_runtime_templates(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_notebook_runtime_templates_rest_required_fields( request_type=notebook_service.ListNotebookRuntimeTemplatesRequest, ): @@ -5793,6 +6887,51 @@ def test_delete_notebook_runtime_template_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_notebook_runtime_template_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_notebook_runtime_template + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_notebook_runtime_template + ] = mock_rpc + + request = {} + client.delete_notebook_runtime_template(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_notebook_runtime_template(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_notebook_runtime_template_rest_required_fields( request_type=notebook_service.DeleteNotebookRuntimeTemplateRequest, ): @@ -6061,6 +7200,51 @@ def test_assign_notebook_runtime_rest(request_type): assert response.operation.name == "operations/spam" +def test_assign_notebook_runtime_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.assign_notebook_runtime + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.assign_notebook_runtime + ] = mock_rpc + + request = {} + client.assign_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.assign_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_assign_notebook_runtime_rest_required_fields( request_type=notebook_service.AssignNotebookRuntimeRequest, ): @@ -6377,6 +7561,46 @@ def test_get_notebook_runtime_rest(request_type): assert response.network_tags == ["network_tags_value"] +def test_get_notebook_runtime_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_notebook_runtime in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_notebook_runtime + ] = mock_rpc + + request = {} + client.get_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_notebook_runtime_rest_required_fields( request_type=notebook_service.GetNotebookRuntimeRequest, ): @@ -6648,6 +7872,47 @@ def test_list_notebook_runtimes_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_notebook_runtimes_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_notebook_runtimes + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_notebook_runtimes + ] = mock_rpc + + request = {} + client.list_notebook_runtimes(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_notebook_runtimes(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_notebook_runtimes_rest_required_fields( request_type=notebook_service.ListNotebookRuntimesRequest, ): @@ -6994,6 +8259,51 @@ def test_delete_notebook_runtime_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_notebook_runtime_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_notebook_runtime + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_notebook_runtime + ] = mock_rpc + + request = {} + client.delete_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_notebook_runtime_rest_required_fields( request_type=notebook_service.DeleteNotebookRuntimeRequest, ): @@ -7259,6 +8569,51 @@ def test_upgrade_notebook_runtime_rest(request_type): assert response.operation.name == "operations/spam" +def test_upgrade_notebook_runtime_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.upgrade_notebook_runtime + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.upgrade_notebook_runtime + ] = mock_rpc + + request = {} + client.upgrade_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.upgrade_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_upgrade_notebook_runtime_rest_required_fields( request_type=notebook_service.UpgradeNotebookRuntimeRequest, ): @@ -7525,6 +8880,51 @@ def test_start_notebook_runtime_rest(request_type): assert response.operation.name == "operations/spam" +def test_start_notebook_runtime_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.start_notebook_runtime + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.start_notebook_runtime + ] = mock_rpc + + request = {} + client.start_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.start_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_start_notebook_runtime_rest_required_fields( request_type=notebook_service.StartNotebookRuntimeRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_persistent_resource_service.py b/tests/unit/gapic/aiplatform_v1/test_persistent_resource_service.py index 9ff55ff500..81766ebc2d 100644 --- a/tests/unit/gapic/aiplatform_v1/test_persistent_resource_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_persistent_resource_service.py @@ -1276,6 +1276,9 @@ def test_create_persistent_resource_empty_call(): with mock.patch.object( type(client.transport.create_persistent_resource), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_persistent_resource() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1302,6 +1305,9 @@ def test_create_persistent_resource_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.create_persistent_resource), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_persistent_resource(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1311,6 +1317,50 @@ def test_create_persistent_resource_non_empty_request_with_auto_populated_field( ) +def test_create_persistent_resource_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_persistent_resource + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_persistent_resource + ] = mock_rpc + request = {} + client.create_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_persistent_resource_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1334,6 +1384,56 @@ async def test_create_persistent_resource_empty_call_async(): assert args[0] == persistent_resource_service.CreatePersistentResourceRequest() +@pytest.mark.asyncio +async def test_create_persistent_resource_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PersistentResourceServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_persistent_resource + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_persistent_resource + ] = mock_object + + request = {} + await client.create_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_persistent_resource_async( transport: str = "grpc_asyncio", @@ -1610,6 +1710,9 @@ def test_get_persistent_resource_empty_call(): with mock.patch.object( type(client.transport.get_persistent_resource), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_persistent_resource() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1635,6 +1738,9 @@ def test_get_persistent_resource_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_persistent_resource), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_persistent_resource(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1643,6 +1749,46 @@ def test_get_persistent_resource_non_empty_request_with_auto_populated_field(): ) +def test_get_persistent_resource_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_persistent_resource + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_persistent_resource + ] = mock_rpc + request = {} + client.get_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_persistent_resource_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1672,6 +1818,52 @@ async def test_get_persistent_resource_empty_call_async(): assert args[0] == persistent_resource_service.GetPersistentResourceRequest() +@pytest.mark.asyncio +async def test_get_persistent_resource_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PersistentResourceServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_persistent_resource + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_persistent_resource + ] = mock_object + + request = {} + await client.get_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_persistent_resource_async( transport: str = "grpc_asyncio", @@ -1923,6 +2115,9 @@ def test_list_persistent_resources_empty_call(): with mock.patch.object( type(client.transport.list_persistent_resources), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_persistent_resources() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1949,6 +2144,9 @@ def test_list_persistent_resources_non_empty_request_with_auto_populated_field() with mock.patch.object( type(client.transport.list_persistent_resources), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_persistent_resources(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1958,6 +2156,46 @@ def test_list_persistent_resources_non_empty_request_with_auto_populated_field() ) +def test_list_persistent_resources_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_persistent_resources + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute 
client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_persistent_resources + ] = mock_rpc + request = {} + client.list_persistent_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_persistent_resources(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_persistent_resources_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1983,6 +2221,52 @@ async def test_list_persistent_resources_empty_call_async(): assert args[0] == persistent_resource_service.ListPersistentResourcesRequest() +@pytest.mark.asyncio +async def test_list_persistent_resources_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PersistentResourceServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_persistent_resources + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_persistent_resources + ] = mock_object + + request = {} + await client.list_persistent_resources(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_persistent_resources(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_persistent_resources_async( transport: str = "grpc_asyncio", @@ -2431,6 +2715,9 @@ def test_delete_persistent_resource_empty_call(): with mock.patch.object( type(client.transport.delete_persistent_resource), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_persistent_resource() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2456,6 +2743,9 @@ def test_delete_persistent_resource_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.delete_persistent_resource), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_persistent_resource(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2464,6 +2754,50 @@ def test_delete_persistent_resource_non_empty_request_with_auto_populated_field( ) +def test_delete_persistent_resource_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_persistent_resource + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_persistent_resource + ] = mock_rpc + request = {} + client.delete_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_persistent_resource_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2487,6 +2821,56 @@ async def test_delete_persistent_resource_empty_call_async(): assert args[0] == persistent_resource_service.DeletePersistentResourceRequest() +@pytest.mark.asyncio +async def test_delete_persistent_resource_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PersistentResourceServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_persistent_resource + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_persistent_resource + ] = mock_object + + request = {} + await client.delete_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_persistent_resource_async( transport: str = "grpc_asyncio", @@ -2724,6 +3108,9 @@ def test_update_persistent_resource_empty_call(): with mock.patch.object( type(client.transport.update_persistent_resource), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_persistent_resource() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2747,12 +3134,59 @@ def test_update_persistent_resource_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.update_persistent_resource), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_persistent_resource(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == persistent_resource_service.UpdatePersistentResourceRequest() +def test_update_persistent_resource_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_persistent_resource + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_persistent_resource + ] = mock_rpc + request = {} + client.update_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_persistent_resource_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2776,6 +3210,56 @@ async def test_update_persistent_resource_empty_call_async(): assert args[0] == persistent_resource_service.UpdatePersistentResourceRequest() +@pytest.mark.asyncio +async def test_update_persistent_resource_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PersistentResourceServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_persistent_resource + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_persistent_resource + ] = mock_object + + request = {} + await client.update_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_persistent_resource_async( transport: str = "grpc_asyncio", @@ -3031,6 +3515,9 @@ def test_reboot_persistent_resource_empty_call(): with mock.patch.object( type(client.transport.reboot_persistent_resource), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.reboot_persistent_resource() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3056,6 +3543,9 @@ def test_reboot_persistent_resource_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.reboot_persistent_resource), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.reboot_persistent_resource(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3064,6 +3554,50 @@ def test_reboot_persistent_resource_non_empty_request_with_auto_populated_field( ) +def test_reboot_persistent_resource_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.reboot_persistent_resource + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.reboot_persistent_resource + ] = mock_rpc + request = {} + client.reboot_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.reboot_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_reboot_persistent_resource_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3087,6 +3621,56 @@ async def test_reboot_persistent_resource_empty_call_async(): assert args[0] == persistent_resource_service.RebootPersistentResourceRequest() +@pytest.mark.asyncio +async def test_reboot_persistent_resource_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PersistentResourceServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.reboot_persistent_resource + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.reboot_persistent_resource + ] = mock_object + + request = {} + await client.reboot_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.reboot_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_reboot_persistent_resource_async( transport: str = "grpc_asyncio", @@ -3437,6 +4021,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_persistent_resource_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_persistent_resource + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_persistent_resource + ] = mock_rpc + + request = {} + client.create_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_persistent_resource_rest_required_fields( request_type=persistent_resource_service.CreatePersistentResourceRequest, ): @@ -3749,6 +4378,47 @@ def test_get_persistent_resource_rest(request_type): assert response.reserved_ip_ranges == ["reserved_ip_ranges_value"] +def test_get_persistent_resource_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_persistent_resource + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_persistent_resource + ] = mock_rpc + + request = {} + client.get_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_persistent_resource_rest_required_fields( request_type=persistent_resource_service.GetPersistentResourceRequest, ): @@ -4025,6 +4695,47 @@ def test_list_persistent_resources_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_persistent_resources_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_persistent_resources + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_persistent_resources + ] = mock_rpc + + request = {} + client.list_persistent_resources(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_persistent_resources(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_persistent_resources_rest_required_fields( request_type=persistent_resource_service.ListPersistentResourcesRequest, ): @@ -4377,6 +5088,51 @@ def test_delete_persistent_resource_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_persistent_resource_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_persistent_resource + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_persistent_resource + ] = mock_rpc + + request = {} + client.delete_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_persistent_resource_rest_required_fields( request_type=persistent_resource_service.DeletePersistentResourceRequest, ): @@ -4772,6 +5528,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_persistent_resource_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_persistent_resource + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_persistent_resource + ] = mock_rpc + + request = {} + client.update_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_persistent_resource_rest_required_fields( request_type=persistent_resource_service.UpdatePersistentResourceRequest, ): @@ -5056,6 +5857,51 @@ def test_reboot_persistent_resource_rest(request_type): assert response.operation.name == "operations/spam" +def test_reboot_persistent_resource_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.reboot_persistent_resource + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.reboot_persistent_resource + ] = mock_rpc + + request = {} + client.reboot_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.reboot_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_reboot_persistent_resource_rest_required_fields( request_type=persistent_resource_service.RebootPersistentResourceRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py index 3f1e2af8e8..0bc897a722 100644 --- a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py @@ -1239,6 +1239,9 @@ def test_create_training_pipeline_empty_call(): with mock.patch.object( type(client.transport.create_training_pipeline), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1264,6 +1267,9 @@ def test_create_training_pipeline_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_training_pipeline), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_training_pipeline(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1272,6 +1278,46 @@ def test_create_training_pipeline_non_empty_request_with_auto_populated_field(): ) +def test_create_training_pipeline_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_training_pipeline + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_training_pipeline + ] = mock_rpc + request = {} + client.create_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_training_pipeline_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1302,6 +1348,52 @@ async def test_create_training_pipeline_empty_call_async(): assert args[0] == pipeline_service.CreateTrainingPipelineRequest() +@pytest.mark.asyncio +async def test_create_training_pipeline_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_training_pipeline + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_training_pipeline + ] = mock_object + + request = {} + await client.create_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_training_pipeline_async( transport: str = "grpc_asyncio", @@ -1575,6 +1667,9 @@ def test_get_training_pipeline_empty_call(): with mock.patch.object( type(client.transport.get_training_pipeline), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1600,6 +1695,9 @@ def test_get_training_pipeline_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_training_pipeline), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_training_pipeline(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1608,6 +1706,46 @@ def test_get_training_pipeline_non_empty_request_with_auto_populated_field(): ) +def test_get_training_pipeline_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_training_pipeline + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.get_training_pipeline + ] = mock_rpc + request = {} + client.get_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_training_pipeline_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1638,6 +1776,52 @@ async def test_get_training_pipeline_empty_call_async(): assert args[0] == pipeline_service.GetTrainingPipelineRequest() +@pytest.mark.asyncio +async def test_get_training_pipeline_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_training_pipeline + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_training_pipeline + ] = mock_object + + request = {} + await client.get_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_training_pipeline_async( transport: str = "grpc_asyncio", @@ -1891,6 +2075,9 @@ def test_list_training_pipelines_empty_call(): with mock.patch.object( type(client.transport.list_training_pipelines), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_training_pipelines() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1918,6 +2105,9 @@ def test_list_training_pipelines_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_training_pipelines), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_training_pipelines(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1928,6 +2118,46 @@ def test_list_training_pipelines_non_empty_request_with_auto_populated_field(): ) +def test_list_training_pipelines_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_training_pipelines + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_training_pipelines + ] = mock_rpc + request = {} + client.list_training_pipelines(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_training_pipelines(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_training_pipelines_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1953,6 +2183,52 @@ async def test_list_training_pipelines_empty_call_async(): assert args[0] == pipeline_service.ListTrainingPipelinesRequest() +@pytest.mark.asyncio +async def test_list_training_pipelines_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_training_pipelines + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_training_pipelines + ] = mock_object + + request = {} + await client.list_training_pipelines(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_training_pipelines(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_training_pipelines_async( transport: str = "grpc_asyncio", @@ -2391,6 +2667,9 @@ def test_delete_training_pipeline_empty_call(): with mock.patch.object( type(client.transport.delete_training_pipeline), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2416,6 +2695,9 @@ def test_delete_training_pipeline_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_training_pipeline), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_training_pipeline(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2424,6 +2706,50 @@ def test_delete_training_pipeline_non_empty_request_with_auto_populated_field(): ) +def test_delete_training_pipeline_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_training_pipeline + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a 
string. + ) + client._transport._wrapped_methods[ + client._transport.delete_training_pipeline + ] = mock_rpc + request = {} + client.delete_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_training_pipeline_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2447,6 +2773,56 @@ async def test_delete_training_pipeline_empty_call_async(): assert args[0] == pipeline_service.DeleteTrainingPipelineRequest() +@pytest.mark.asyncio +async def test_delete_training_pipeline_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_training_pipeline + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_training_pipeline + ] = mock_object + + request = {} + await client.delete_training_pipeline(request) + + # 
Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_training_pipeline_async( transport: str = "grpc_asyncio", @@ -2684,6 +3060,9 @@ def test_cancel_training_pipeline_empty_call(): with mock.patch.object( type(client.transport.cancel_training_pipeline), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.cancel_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2709,6 +3088,9 @@ def test_cancel_training_pipeline_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.cancel_training_pipeline), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.cancel_training_pipeline(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2717,6 +3099,46 @@ def test_cancel_training_pipeline_non_empty_request_with_auto_populated_field(): ) +def test_cancel_training_pipeline_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.cancel_training_pipeline + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_training_pipeline + ] = mock_rpc + request = {} + client.cancel_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_training_pipeline_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2738,6 +3160,52 @@ async def test_cancel_training_pipeline_empty_call_async(): assert args[0] == pipeline_service.CancelTrainingPipelineRequest() +@pytest.mark.asyncio +async def test_cancel_training_pipeline_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.cancel_training_pipeline + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.cancel_training_pipeline + ] = mock_object + + request = {} + await client.cancel_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.cancel_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_training_pipeline_async( transport: str = "grpc_asyncio", @@ -2986,6 +3454,9 @@ def test_create_pipeline_job_empty_call(): with mock.patch.object( type(client.transport.create_pipeline_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_pipeline_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3012,6 +3483,9 @@ def test_create_pipeline_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_pipeline_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_pipeline_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3021,6 +3495,45 @@ def test_create_pipeline_job_non_empty_request_with_auto_populated_field(): ) +def test_create_pipeline_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_pipeline_job in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.create_pipeline_job + ] = mock_rpc + request = {} + client.create_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.create_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_pipeline_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3053,6 +3566,52 @@ async def test_create_pipeline_job_empty_call_async(): assert args[0] == pipeline_service.CreatePipelineJobRequest() +@pytest.mark.asyncio +async def test_create_pipeline_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_pipeline_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_pipeline_job + ] = mock_object + + request = {} + await client.create_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_pipeline_job_async( transport: str = "grpc_asyncio", @@ -3340,6 +3899,9 @@ def test_get_pipeline_job_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_pipeline_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_pipeline_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3363,6 +3925,9 @@ def test_get_pipeline_job_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_pipeline_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_pipeline_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3371,6 +3936,43 @@ def test_get_pipeline_job_non_empty_request_with_auto_populated_field(): ) +def test_get_pipeline_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_pipeline_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_pipeline_job + ] = mock_rpc + request = {} + client.get_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_pipeline_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3401,6 +4003,52 @@ async def test_get_pipeline_job_empty_call_async(): assert args[0] == pipeline_service.GetPipelineJobRequest() +@pytest.mark.asyncio +async def test_get_pipeline_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_pipeline_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_pipeline_job + ] = mock_object + + request = {} + await client.get_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_pipeline_job_async( transport: str = "grpc_asyncio", request_type=pipeline_service.GetPipelineJobRequest @@ -3647,6 +4295,9 @@ def test_list_pipeline_jobs_empty_call(): with mock.patch.object( type(client.transport.list_pipeline_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_pipeline_jobs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3675,6 +4326,9 @@ def test_list_pipeline_jobs_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_pipeline_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_pipeline_jobs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3686,6 +4340,45 @@ def test_list_pipeline_jobs_non_empty_request_with_auto_populated_field(): ) +def test_list_pipeline_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_pipeline_jobs in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_pipeline_jobs + ] = mock_rpc + request = {} + client.list_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_pipeline_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_pipeline_jobs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3711,6 +4404,52 @@ async def test_list_pipeline_jobs_empty_call_async(): assert args[0] == pipeline_service.ListPipelineJobsRequest() +@pytest.mark.asyncio +async def test_list_pipeline_jobs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_pipeline_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_pipeline_jobs + ] = mock_object + + request = {} + await client.list_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_pipeline_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_pipeline_jobs_async( transport: str = "grpc_asyncio", @@ -4149,6 +4888,9 @@ def test_delete_pipeline_job_empty_call(): with mock.patch.object( type(client.transport.delete_pipeline_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_pipeline_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4174,6 +4916,9 @@ def test_delete_pipeline_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_pipeline_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_pipeline_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4182,6 +4927,49 @@ def test_delete_pipeline_job_non_empty_request_with_auto_populated_field(): ) +def test_delete_pipeline_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_pipeline_job in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_pipeline_job + ] = mock_rpc + request = {} + client.delete_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_pipeline_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4205,6 +4993,56 @@ async def test_delete_pipeline_job_empty_call_async(): assert args[0] == pipeline_service.DeletePipelineJobRequest() +@pytest.mark.asyncio +async def test_delete_pipeline_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_pipeline_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_pipeline_job + ] = mock_object + + request = {} + await client.delete_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_pipeline_job_async( transport: str = "grpc_asyncio", @@ -4442,6 +5280,9 @@ def test_batch_delete_pipeline_jobs_empty_call(): with mock.patch.object( type(client.transport.batch_delete_pipeline_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_delete_pipeline_jobs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4467,6 +5308,9 @@ def test_batch_delete_pipeline_jobs_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.batch_delete_pipeline_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.batch_delete_pipeline_jobs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4475,6 +5319,50 @@ def test_batch_delete_pipeline_jobs_non_empty_request_with_auto_populated_field( ) +def test_batch_delete_pipeline_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_delete_pipeline_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_delete_pipeline_jobs + ] = mock_rpc + request = {} + client.batch_delete_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.batch_delete_pipeline_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_batch_delete_pipeline_jobs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4498,6 +5386,56 @@ async def test_batch_delete_pipeline_jobs_empty_call_async(): assert args[0] == pipeline_service.BatchDeletePipelineJobsRequest() +@pytest.mark.asyncio +async def test_batch_delete_pipeline_jobs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_delete_pipeline_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_delete_pipeline_jobs + ] = mock_object + + request = {} + await client.batch_delete_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.batch_delete_pipeline_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_batch_delete_pipeline_jobs_async( transport: str = "grpc_asyncio", @@ -4745,6 +5683,9 @@ def test_cancel_pipeline_job_empty_call(): with mock.patch.object( type(client.transport.cancel_pipeline_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.cancel_pipeline_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4770,6 +5711,9 @@ def test_cancel_pipeline_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.cancel_pipeline_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.cancel_pipeline_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4778,6 +5722,45 @@ def test_cancel_pipeline_job_non_empty_request_with_auto_populated_field(): ) +def test_cancel_pipeline_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.cancel_pipeline_job in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_pipeline_job + ] = mock_rpc + request = {} + client.cancel_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_pipeline_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4799,6 +5782,52 @@ async def test_cancel_pipeline_job_empty_call_async(): assert args[0] == pipeline_service.CancelPipelineJobRequest() +@pytest.mark.asyncio +async def test_cancel_pipeline_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.cancel_pipeline_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.cancel_pipeline_job + ] = mock_object + + request = {} + await client.cancel_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.cancel_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_pipeline_job_async( transport: str = "grpc_asyncio", @@ -5030,6 +6059,9 @@ def test_batch_cancel_pipeline_jobs_empty_call(): with mock.patch.object( type(client.transport.batch_cancel_pipeline_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_cancel_pipeline_jobs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5051,16 +6083,63 @@ def test_batch_cancel_pipeline_jobs_non_empty_request_with_auto_populated_field( parent="parent_value", ) - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.batch_cancel_pipeline_jobs), "__call__" - ) as call: - client.batch_cancel_pipeline_jobs(request=request) - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == pipeline_service.BatchCancelPipelineJobsRequest( - parent="parent_value", + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_cancel_pipeline_jobs), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client.batch_cancel_pipeline_jobs(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == pipeline_service.BatchCancelPipelineJobsRequest( + parent="parent_value", + ) + + +def test_batch_cancel_pipeline_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_cancel_pipeline_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. ) + client._transport._wrapped_methods[ + client._transport.batch_cancel_pipeline_jobs + ] = mock_rpc + request = {} + client.batch_cancel_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.batch_cancel_pipeline_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 @pytest.mark.asyncio @@ -5086,6 +6165,56 @@ async def test_batch_cancel_pipeline_jobs_empty_call_async(): assert args[0] == pipeline_service.BatchCancelPipelineJobsRequest() +@pytest.mark.asyncio +async def test_batch_cancel_pipeline_jobs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_cancel_pipeline_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_cancel_pipeline_jobs + ] = mock_object + + request = {} + await client.batch_cancel_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.batch_cancel_pipeline_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_batch_cancel_pipeline_jobs_async( transport: str = "grpc_asyncio", @@ -5584,6 +6713,47 @@ def get_message_fields(field): assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED +def test_create_training_pipeline_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_training_pipeline + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_training_pipeline + ] = mock_rpc + + request = {} + client.create_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_training_pipeline_rest_required_fields( request_type=pipeline_service.CreateTrainingPipelineRequest, ): @@ -5874,6 +7044,47 @@ def test_get_training_pipeline_rest(request_type): assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED +def test_get_training_pipeline_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_training_pipeline + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_training_pipeline + ] = mock_rpc + + request = {} + client.get_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_training_pipeline_rest_required_fields( request_type=pipeline_service.GetTrainingPipelineRequest, ): @@ -6145,6 +7356,47 @@ def test_list_training_pipelines_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_training_pipelines_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_training_pipelines + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_training_pipelines + ] = mock_rpc + + request = {} + client.list_training_pipelines(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_training_pipelines(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_training_pipelines_rest_required_fields( request_type=pipeline_service.ListTrainingPipelinesRequest, ): @@ -6489,6 +7741,51 @@ def test_delete_training_pipeline_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_training_pipeline_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_training_pipeline + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_training_pipeline + ] = mock_rpc + + request = {} + client.delete_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_training_pipeline_rest_required_fields( request_type=pipeline_service.DeleteTrainingPipelineRequest, ): @@ -6754,6 +8051,47 @@ def test_cancel_training_pipeline_rest(request_type): assert response is None +def test_cancel_training_pipeline_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.cancel_training_pipeline + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_training_pipeline + ] = mock_rpc + + request = {} + client.cancel_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_cancel_training_pipeline_rest_required_fields( request_type=pipeline_service.CancelTrainingPipelineRequest, ): @@ -7196,6 +8534,46 @@ def get_message_fields(field): assert response.schedule_name == "schedule_name_value" +def test_create_pipeline_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_pipeline_job in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_pipeline_job + ] = mock_rpc + + request = {} + client.create_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_pipeline_job_rest_required_fields( request_type=pipeline_service.CreatePipelineJobRequest, ): @@ -7492,6 +8870,44 @@ def test_get_pipeline_job_rest(request_type): assert response.schedule_name == "schedule_name_value" +def test_get_pipeline_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_pipeline_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_pipeline_job + ] = mock_rpc + + request = {} + client.get_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_pipeline_job_rest_required_fields( request_type=pipeline_service.GetPipelineJobRequest, ): @@ -7761,6 +9177,46 @@ def test_list_pipeline_jobs_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_pipeline_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_pipeline_jobs in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_pipeline_jobs + ] = mock_rpc + + request = {} + client.list_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_pipeline_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_pipeline_jobs_rest_required_fields( request_type=pipeline_service.ListPipelineJobsRequest, ): @@ -8101,6 +9557,50 @@ def test_delete_pipeline_job_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_pipeline_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_pipeline_job in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_pipeline_job + ] = mock_rpc + + request = {} + client.delete_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_pipeline_job_rest_required_fields( request_type=pipeline_service.DeletePipelineJobRequest, ): @@ -8362,6 +9862,51 @@ def test_batch_delete_pipeline_jobs_rest(request_type): assert response.operation.name == "operations/spam" +def test_batch_delete_pipeline_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_delete_pipeline_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_delete_pipeline_jobs + ] = mock_rpc + + request = {} + client.batch_delete_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.batch_delete_pipeline_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_batch_delete_pipeline_jobs_rest_required_fields( request_type=pipeline_service.BatchDeletePipelineJobsRequest, ): @@ -8637,6 +10182,46 @@ def test_cancel_pipeline_job_rest(request_type): assert response is None +def test_cancel_pipeline_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.cancel_pipeline_job in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_pipeline_job + ] = mock_rpc + + request = {} + client.cancel_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_cancel_pipeline_job_rest_required_fields( request_type=pipeline_service.CancelPipelineJobRequest, ): @@ -8889,6 +10474,51 @@ def test_batch_cancel_pipeline_jobs_rest(request_type): assert response.operation.name == "operations/spam" +def test_batch_cancel_pipeline_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_cancel_pipeline_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_cancel_pipeline_jobs + ] = mock_rpc + + request = {} + client.batch_cancel_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.batch_cancel_pipeline_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_batch_cancel_pipeline_jobs_rest_required_fields( request_type=pipeline_service.BatchCancelPipelineJobsRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_prediction_service.py b/tests/unit/gapic/aiplatform_v1/test_prediction_service.py index 2cb958fec4..d609bc4805 100644 --- a/tests/unit/gapic/aiplatform_v1/test_prediction_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_prediction_service.py @@ -1227,6 +1227,9 @@ def test_predict_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.predict), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.predict() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1250,6 +1253,9 @@ def test_predict_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.predict), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.predict(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1258,6 +1264,41 @@ def test_predict_non_empty_request_with_auto_populated_field(): ) +def test_predict_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.predict in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.predict] = mock_rpc + request = {} + client.predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_predict_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1284,6 +1325,50 @@ async def test_predict_empty_call_async(): assert args[0] == prediction_service.PredictRequest() +@pytest.mark.asyncio +async def test_predict_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.predict + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.predict + ] = mock_object + + request = {} + await client.predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_predict_async( transport: str = "grpc_asyncio", request_type=prediction_service.PredictRequest @@ -1471,6 +1556,9 @@ def test_raw_predict_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.raw_predict), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.raw_predict() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1494,6 +1582,9 @@ def test_raw_predict_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.raw_predict), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.raw_predict(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1502,6 +1593,41 @@ def test_raw_predict_non_empty_request_with_auto_populated_field(): ) +def test_raw_predict_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.raw_predict in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.raw_predict] = mock_rpc + request = {} + client.raw_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_raw_predict_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1526,6 +1652,52 @@ async def test_raw_predict_empty_call_async(): assert args[0] == prediction_service.RawPredictRequest() +@pytest.mark.asyncio +async def test_raw_predict_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.raw_predict + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.raw_predict + ] = mock_object + + request = {} + await client.raw_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_raw_predict_async( transport: str = "grpc_asyncio", request_type=prediction_service.RawPredictRequest @@ -1768,6 +1940,9 @@ def test_stream_raw_predict_empty_call(): with mock.patch.object( type(client.transport.stream_raw_predict), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.stream_raw_predict() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1793,6 +1968,9 @@ def test_stream_raw_predict_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.stream_raw_predict), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.stream_raw_predict(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1801,6 +1979,45 @@ def test_stream_raw_predict_non_empty_request_with_auto_populated_field(): ) +def test_stream_raw_predict_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.stream_raw_predict in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.stream_raw_predict + ] = mock_rpc + request = {} + client.stream_raw_predict(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.stream_raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_stream_raw_predict_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1823,6 +2040,52 @@ async def test_stream_raw_predict_empty_call_async(): assert args[0] == prediction_service.StreamRawPredictRequest() +@pytest.mark.asyncio +async def test_stream_raw_predict_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.stream_raw_predict + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.stream_raw_predict + ] = mock_object + + request = {} + await client.stream_raw_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.stream_raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_stream_raw_predict_async( transport: str = "grpc_asyncio", @@ -2063,6 +2326,9 @@ def test_direct_predict_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.direct_predict), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.direct_predict() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2086,6 +2352,9 @@ def test_direct_predict_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.direct_predict), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.direct_predict(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2094,6 +2363,41 @@ def test_direct_predict_non_empty_request_with_auto_populated_field(): ) +def test_direct_predict_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.direct_predict in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.direct_predict] = mock_rpc + request = {} + client.direct_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.direct_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_direct_predict_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2115,6 +2419,52 @@ async def test_direct_predict_empty_call_async(): assert args[0] == prediction_service.DirectPredictRequest() +@pytest.mark.asyncio +async def test_direct_predict_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.direct_predict + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.direct_predict + ] = mock_object + + request = {} + await client.direct_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.direct_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_direct_predict_async( transport: str = "grpc_asyncio", @@ -2263,6 +2613,9 @@ def test_direct_raw_predict_empty_call(): with mock.patch.object( type(client.transport.direct_raw_predict), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.direct_raw_predict() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2289,6 +2642,9 @@ def test_direct_raw_predict_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.direct_raw_predict), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.direct_raw_predict(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2298,6 +2654,45 @@ def test_direct_raw_predict_non_empty_request_with_auto_populated_field(): ) +def test_direct_raw_predict_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.direct_raw_predict in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.direct_raw_predict + ] = mock_rpc + request = {} + client.direct_raw_predict(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.direct_raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_direct_raw_predict_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2323,6 +2718,52 @@ async def test_direct_raw_predict_empty_call_async(): assert args[0] == prediction_service.DirectRawPredictRequest() +@pytest.mark.asyncio +async def test_direct_raw_predict_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.direct_raw_predict + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.direct_raw_predict + ] = mock_object + + request = {} + await client.direct_raw_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.direct_raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_direct_raw_predict_async( transport: str = "grpc_asyncio", @@ -2466,6 +2907,92 @@ def test_stream_direct_predict(request_type, transport: str = "grpc"): assert isinstance(message, prediction_service.StreamDirectPredictResponse) +def test_stream_direct_predict_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.stream_direct_predict + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.stream_direct_predict + ] = mock_rpc + request = [{}] + client.stream_direct_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.stream_direct_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_stream_direct_predict_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.stream_direct_predict + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.stream_direct_predict + ] = mock_object + + request = [{}] + await client.stream_direct_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.stream_direct_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_stream_direct_predict_async( transport: str = "grpc_asyncio", @@ -2543,6 +3070,92 @@ def test_stream_direct_raw_predict(request_type, transport: str = "grpc"): assert isinstance(message, prediction_service.StreamDirectRawPredictResponse) +def test_stream_direct_raw_predict_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.stream_direct_raw_predict + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.stream_direct_raw_predict + ] = mock_rpc + request = [{}] + client.stream_direct_raw_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.stream_direct_raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_stream_direct_raw_predict_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.stream_direct_raw_predict + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.stream_direct_raw_predict + ] = mock_object + + request = [{}] + await client.stream_direct_raw_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.stream_direct_raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_stream_direct_raw_predict_async( transport: str = "grpc_asyncio", @@ -2620,6 +3233,89 @@ def test_streaming_predict(request_type, transport: str = "grpc"): assert isinstance(message, prediction_service.StreamingPredictResponse) +def test_streaming_predict_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.streaming_predict in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.streaming_predict + ] = mock_rpc + request = [{}] + client.streaming_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.streaming_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_streaming_predict_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.streaming_predict + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.streaming_predict + ] = mock_object + + request = [{}] + await client.streaming_predict(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.streaming_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_streaming_predict_async( transport: str = "grpc_asyncio", @@ -2709,6 +3405,9 @@ def test_server_streaming_predict_empty_call(): with mock.patch.object( type(client.transport.server_streaming_predict), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.server_streaming_predict() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2734,6 +3433,9 @@ def test_server_streaming_predict_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.server_streaming_predict), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.server_streaming_predict(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2742,6 +3444,46 @@ def test_server_streaming_predict_non_empty_request_with_auto_populated_field(): ) +def test_server_streaming_predict_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.server_streaming_predict + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.server_streaming_predict + ] = mock_rpc + request = {} + client.server_streaming_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.server_streaming_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_server_streaming_predict_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2766,6 +3508,52 @@ async def test_server_streaming_predict_empty_call_async(): assert args[0] == prediction_service.StreamingPredictRequest() +@pytest.mark.asyncio +async def test_server_streaming_predict_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.server_streaming_predict + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.server_streaming_predict + ] = mock_object + + request = {} + await client.server_streaming_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.server_streaming_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_server_streaming_predict_async( transport: str = "grpc_asyncio", @@ -2909,6 +3697,92 @@ def test_streaming_raw_predict(request_type, transport: str = "grpc"): assert isinstance(message, prediction_service.StreamingRawPredictResponse) +def test_streaming_raw_predict_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.streaming_raw_predict + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.streaming_raw_predict + ] = mock_rpc + request = [{}] + client.streaming_raw_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.streaming_raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_streaming_raw_predict_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.streaming_raw_predict + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.streaming_raw_predict + ] = mock_object + + request = [{}] + await client.streaming_raw_predict(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.streaming_raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_streaming_raw_predict_async( transport: str = "grpc_asyncio", @@ -2996,6 +3870,9 @@ def test_explain_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.explain), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.explain() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3020,6 +3897,9 @@ def test_explain_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.explain), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.explain(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3029,6 +3909,41 @@ def test_explain_non_empty_request_with_auto_populated_field(): ) +def test_explain_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.explain in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.explain] = mock_rpc + request = {} + client.explain(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.explain(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_explain_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3052,6 +3967,50 @@ async def test_explain_empty_call_async(): assert args[0] == prediction_service.ExplainRequest() +@pytest.mark.asyncio +async def test_explain_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.explain + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.explain + ] = mock_object + + request = {} + await client.explain(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.explain(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_explain_async( transport: str = "grpc_asyncio", request_type=prediction_service.ExplainRequest @@ -3230,6 +4189,9 @@ def test_generate_content_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.generate_content), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.generate_content() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3253,6 +4215,9 @@ def test_generate_content_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.generate_content), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.generate_content(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3261,6 +4226,43 @@ def test_generate_content_non_empty_request_with_auto_populated_field(): ) +def test_generate_content_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.generate_content in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.generate_content + ] = mock_rpc + request = {} + client.generate_content(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.generate_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_generate_content_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3282,6 +4284,52 @@ async def test_generate_content_empty_call_async(): assert args[0] == prediction_service.GenerateContentRequest() +@pytest.mark.asyncio +async def test_generate_content_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.generate_content + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.generate_content + ] = mock_object + + request = {} + await client.generate_content(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.generate_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_generate_content_async( transport: str = "grpc_asyncio", @@ -3520,6 +4568,9 @@ def test_stream_generate_content_empty_call(): with mock.patch.object( type(client.transport.stream_generate_content), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.stream_generate_content() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3545,6 +4596,9 @@ def test_stream_generate_content_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.stream_generate_content), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.stream_generate_content(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3553,6 +4607,46 @@ def test_stream_generate_content_non_empty_request_with_auto_populated_field(): ) +def test_stream_generate_content_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.stream_generate_content + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.stream_generate_content + ] = mock_rpc + request = {} + client.stream_generate_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.stream_generate_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_stream_generate_content_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3577,6 +4671,52 @@ async def test_stream_generate_content_empty_call_async(): assert args[0] == prediction_service.GenerateContentRequest() +@pytest.mark.asyncio +async def test_stream_generate_content_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.stream_generate_content + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.stream_generate_content + ] = mock_object + + request = {} + await client.stream_generate_content(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.stream_generate_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_stream_generate_content_async( transport: str = "grpc_asyncio", @@ -3824,6 +4964,42 @@ def test_predict_rest(request_type): assert response.model_display_name == "model_display_name_value" +def test_predict_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.predict in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.predict] = mock_rpc + + request = {} + client.predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_predict_rest_required_fields(request_type=prediction_service.PredictRequest): transport_class = transports.PredictionServiceRestTransport @@ -4104,6 +5280,42 @@ def test_raw_predict_rest(request_type): assert response.data == b"data_blob" +def test_raw_predict_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.raw_predict in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.raw_predict] = mock_rpc + + request = {} + client.raw_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_raw_predict_rest_required_fields( request_type=prediction_service.RawPredictRequest, ): @@ -4377,6 +5589,46 @@ def test_stream_raw_predict_rest(request_type): assert response.data == b"data_blob" +def test_stream_raw_predict_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.stream_raw_predict in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.stream_raw_predict + ] = mock_rpc + + request = {} + client.stream_raw_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.stream_raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_stream_raw_predict_rest_required_fields( request_type=prediction_service.StreamRawPredictRequest, ): @@ -4647,6 +5899,42 @@ def test_direct_predict_rest(request_type): assert isinstance(response, prediction_service.DirectPredictResponse) +def test_direct_predict_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.direct_predict in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.direct_predict] = mock_rpc + + request = {} + client.direct_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.direct_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_direct_predict_rest_required_fields( request_type=prediction_service.DirectPredictRequest, ): @@ -4858,6 +6146,46 @@ def test_direct_raw_predict_rest(request_type): assert response.output == b"output_blob" +def test_direct_raw_predict_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.direct_raw_predict in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.direct_raw_predict + ] = mock_rpc + + request = {} + client.direct_raw_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.direct_raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_direct_raw_predict_rest_required_fields( request_type=prediction_service.DirectRawPredictRequest, ): @@ -5106,6 +6434,47 @@ def test_server_streaming_predict_rest(request_type): assert isinstance(response, prediction_service.StreamingPredictResponse) +def test_server_streaming_predict_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.server_streaming_predict + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.server_streaming_predict + ] = mock_rpc + + request = {} + client.server_streaming_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.server_streaming_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_server_streaming_predict_rest_required_fields( request_type=prediction_service.StreamingPredictRequest, ): @@ -5332,6 +6701,42 @@ def test_explain_rest(request_type): assert response.deployed_model_id == "deployed_model_id_value" +def test_explain_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.explain in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.explain] = mock_rpc + + request = {} + client.explain(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.explain(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_explain_rest_required_fields(request_type=prediction_service.ExplainRequest): transport_class = transports.PredictionServiceRestTransport @@ -5611,6 +7016,44 @@ def test_generate_content_rest(request_type): assert isinstance(response, prediction_service.GenerateContentResponse) +def test_generate_content_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.generate_content in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.generate_content + ] = mock_rpc + + request = {} + client.generate_content(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.generate_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_generate_content_rest_required_fields( request_type=prediction_service.GenerateContentRequest, ): @@ -5895,6 +7338,47 @@ def test_stream_generate_content_rest(request_type): assert isinstance(response, prediction_service.GenerateContentResponse) +def test_stream_generate_content_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.stream_generate_content + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.stream_generate_content + ] = mock_rpc + + request = {} + client.stream_generate_content(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.stream_generate_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_stream_generate_content_rest_required_fields( request_type=prediction_service.GenerateContentRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_schedule_service.py b/tests/unit/gapic/aiplatform_v1/test_schedule_service.py index 76a08bee81..e2012b2f86 100644 --- a/tests/unit/gapic/aiplatform_v1/test_schedule_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_schedule_service.py @@ -1233,6 +1233,9 @@ def test_create_schedule_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_schedule() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1256,6 +1259,9 @@ def test_create_schedule_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_schedule(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1264,6 +1270,41 @@ def test_create_schedule_non_empty_request_with_auto_populated_field(): ) +def test_create_schedule_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_schedule] = mock_rpc + request = {} + client.create_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_schedule_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1294,6 +1335,52 @@ async def test_create_schedule_empty_call_async(): assert args[0] == schedule_service.CreateScheduleRequest() +@pytest.mark.asyncio +async def test_create_schedule_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_schedule + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_schedule + ] = mock_object + + request = {} + await client.create_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_schedule_async( transport: str = "grpc_asyncio", request_type=schedule_service.CreateScheduleRequest @@ -1543,6 +1630,9 @@ def test_delete_schedule_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_schedule() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1566,6 +1656,9 @@ def test_delete_schedule_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_schedule(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1574,6 +1667,45 @@ def test_delete_schedule_non_empty_request_with_auto_populated_field(): ) +def test_delete_schedule_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_schedule] = mock_rpc + request = {} + client.delete_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_schedule_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1595,6 +1727,56 @@ async def test_delete_schedule_empty_call_async(): assert args[0] == schedule_service.DeleteScheduleRequest() +@pytest.mark.asyncio +async def test_delete_schedule_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_schedule + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_schedule + ] = mock_object + + request = {} + await client.delete_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_schedule_async( transport: str = "grpc_asyncio", request_type=schedule_service.DeleteScheduleRequest @@ -1835,6 +2017,9 @@ def test_get_schedule_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_schedule() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1858,6 +2043,9 @@ def test_get_schedule_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_schedule(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1866,6 +2054,41 @@ def test_get_schedule_non_empty_request_with_auto_populated_field(): ) +def test_get_schedule_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_schedule] = mock_rpc + request = {} + client.get_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_schedule_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1896,6 +2119,52 @@ async def test_get_schedule_empty_call_async(): assert args[0] == schedule_service.GetScheduleRequest() +@pytest.mark.asyncio +async def test_get_schedule_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_schedule + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_schedule + ] = mock_object + + request = {} + await client.get_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_schedule_async( transport: str = "grpc_asyncio", request_type=schedule_service.GetScheduleRequest @@ -2134,6 +2403,9 @@ def test_list_schedules_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_schedules), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_schedules() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2160,6 +2432,9 @@ def test_list_schedules_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_schedules), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_schedules(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2171,6 +2446,41 @@ def test_list_schedules_non_empty_request_with_auto_populated_field(): ) +def test_list_schedules_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_schedules in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_schedules] = mock_rpc + request = {} + client.list_schedules(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_schedules(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_schedules_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2194,6 +2504,52 @@ async def test_list_schedules_empty_call_async(): assert args[0] == schedule_service.ListSchedulesRequest() +@pytest.mark.asyncio +async def test_list_schedules_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_schedules + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_schedules + ] = mock_object + + request = {} + await client.list_schedules(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_schedules(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_schedules_async( transport: str = "grpc_asyncio", request_type=schedule_service.ListSchedulesRequest @@ -2609,6 +2965,9 @@ def test_pause_schedule_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.pause_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.pause_schedule() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2632,6 +2991,9 @@ def test_pause_schedule_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.pause_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.pause_schedule(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2640,6 +3002,41 @@ def test_pause_schedule_non_empty_request_with_auto_populated_field(): ) +def test_pause_schedule_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.pause_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.pause_schedule] = mock_rpc + request = {} + client.pause_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.pause_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_pause_schedule_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2659,6 +3056,52 @@ async def test_pause_schedule_empty_call_async(): assert args[0] == schedule_service.PauseScheduleRequest() +@pytest.mark.asyncio +async def test_pause_schedule_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.pause_schedule + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.pause_schedule + ] = mock_object + + request = {} + await client.pause_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.pause_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_pause_schedule_async( transport: str = "grpc_asyncio", request_type=schedule_service.PauseScheduleRequest @@ -2875,6 +3318,9 @@ def test_resume_schedule_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.resume_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.resume_schedule() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2898,6 +3344,9 @@ def test_resume_schedule_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.resume_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.resume_schedule(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2906,6 +3355,41 @@ def test_resume_schedule_non_empty_request_with_auto_populated_field(): ) +def test_resume_schedule_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.resume_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.resume_schedule] = mock_rpc + request = {} + client.resume_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.resume_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_resume_schedule_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2925,6 +3409,52 @@ async def test_resume_schedule_empty_call_async(): assert args[0] == schedule_service.ResumeScheduleRequest() +@pytest.mark.asyncio +async def test_resume_schedule_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.resume_schedule + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.resume_schedule + ] = mock_object + + request = {} + await client.resume_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.resume_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_resume_schedule_async( transport: str = "grpc_asyncio", request_type=schedule_service.ResumeScheduleRequest @@ -3169,6 +3699,9 @@ def test_update_schedule_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_schedule() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3190,12 +3723,50 @@ def test_update_schedule_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_schedule(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == schedule_service.UpdateScheduleRequest() +def test_update_schedule_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_schedule] = mock_rpc + request = {} + client.update_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_schedule_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3226,6 +3797,52 @@ async def test_update_schedule_empty_call_async(): assert args[0] == schedule_service.UpdateScheduleRequest() +@pytest.mark.asyncio +async def test_update_schedule_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_schedule + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_schedule + ] = mock_object + + request = {} + await client.update_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_schedule_async( transport: str = "grpc_asyncio", request_type=schedule_service.UpdateScheduleRequest @@ -3688,6 +4305,42 @@ def get_message_fields(field): assert response.catch_up is True +def test_create_schedule_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_schedule] = mock_rpc + + request = {} + client.create_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_schedule_rest_required_fields( request_type=schedule_service.CreateScheduleRequest, ): @@ -3960,6 +4613,46 @@ def test_delete_schedule_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_schedule_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_schedule] = mock_rpc + + request = {} + client.delete_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_schedule_rest_required_fields( request_type=schedule_service.DeleteScheduleRequest, ): @@ -4240,6 +4933,42 @@ def test_get_schedule_rest(request_type): assert response.catch_up is True +def test_get_schedule_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_schedule] = mock_rpc + + request = {} + client.get_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_schedule_rest_required_fields( request_type=schedule_service.GetScheduleRequest, ): @@ -4506,6 +5235,42 @@ def test_list_schedules_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_schedules_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_schedules in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_schedules] = mock_rpc + + request = {} + client.list_schedules(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_schedules(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_schedules_rest_required_fields( request_type=schedule_service.ListSchedulesRequest, ): @@ -4843,6 +5608,42 @@ def test_pause_schedule_rest(request_type): assert response is None +def test_pause_schedule_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.pause_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.pause_schedule] = mock_rpc + + request = {} + client.pause_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.pause_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_pause_schedule_rest_required_fields( request_type=schedule_service.PauseScheduleRequest, ): @@ -5095,6 +5896,42 @@ def test_resume_schedule_rest(request_type): assert response is None +def test_resume_schedule_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.resume_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.resume_schedule] = mock_rpc + + request = {} + client.resume_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.resume_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_resume_schedule_rest_required_fields( request_type=schedule_service.ResumeScheduleRequest, ): @@ -5572,6 +6409,42 @@ def get_message_fields(field): assert response.catch_up is True +def test_update_schedule_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_schedule] = mock_rpc + + request = {} + client.update_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_schedule_rest_required_fields( request_type=schedule_service.UpdateScheduleRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py index 764b32127a..bc57eb5d36 100644 --- a/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py @@ -1254,6 +1254,9 @@ def test_create_specialist_pool_empty_call(): with mock.patch.object( type(client.transport.create_specialist_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1279,6 +1282,9 @@ def test_create_specialist_pool_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_specialist_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_specialist_pool(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1287,6 +1293,50 @@ def test_create_specialist_pool_non_empty_request_with_auto_populated_field(): ) +def test_create_specialist_pool_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_specialist_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_specialist_pool + ] = mock_rpc + request = {} + client.create_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_specialist_pool_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1310,6 +1360,56 @@ async def test_create_specialist_pool_empty_call_async(): assert args[0] == specialist_pool_service.CreateSpecialistPoolRequest() +@pytest.mark.asyncio +async def test_create_specialist_pool_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_specialist_pool + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_specialist_pool + ] = mock_object + + request = {} + await client.create_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_specialist_pool_async( transport: str = "grpc_asyncio", @@ -1570,6 +1670,9 @@ def test_get_specialist_pool_empty_call(): with mock.patch.object( type(client.transport.get_specialist_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1595,6 +1698,9 @@ def test_get_specialist_pool_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_specialist_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_specialist_pool(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1603,6 +1709,45 @@ def test_get_specialist_pool_non_empty_request_with_auto_populated_field(): ) +def test_get_specialist_pool_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_specialist_pool in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_specialist_pool + ] = mock_rpc + request = {} + client.get_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_specialist_pool_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1633,6 +1778,52 @@ async def test_get_specialist_pool_empty_call_async(): assert args[0] == specialist_pool_service.GetSpecialistPoolRequest() +@pytest.mark.asyncio +async def test_get_specialist_pool_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_specialist_pool + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_specialist_pool + ] = mock_object + + request = {} + await client.get_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_specialist_pool_async( transport: str = "grpc_asyncio", @@ -1886,6 +2077,9 @@ def test_list_specialist_pools_empty_call(): with mock.patch.object( type(client.transport.list_specialist_pools), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_specialist_pools() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1912,6 +2106,9 @@ def test_list_specialist_pools_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_specialist_pools), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_specialist_pools(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1921,6 +2118,46 @@ def test_list_specialist_pools_non_empty_request_with_auto_populated_field(): ) +def test_list_specialist_pools_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_specialist_pools + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_specialist_pools + ] = mock_rpc + request = {} + client.list_specialist_pools(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_specialist_pools(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_specialist_pools_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1946,6 +2183,52 @@ async def test_list_specialist_pools_empty_call_async(): assert args[0] == specialist_pool_service.ListSpecialistPoolsRequest() +@pytest.mark.asyncio +async def test_list_specialist_pools_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_specialist_pools + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_specialist_pools + ] = mock_object + + request = {} + await client.list_specialist_pools(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_specialist_pools(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_specialist_pools_async( transport: str = "grpc_asyncio", @@ -2384,6 +2667,9 @@ def test_delete_specialist_pool_empty_call(): with mock.patch.object( type(client.transport.delete_specialist_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2409,6 +2695,9 @@ def test_delete_specialist_pool_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_specialist_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_specialist_pool(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2417,6 +2706,50 @@ def test_delete_specialist_pool_non_empty_request_with_auto_populated_field(): ) +def test_delete_specialist_pool_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_specialist_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_specialist_pool + ] = mock_rpc + request = {} + client.delete_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_specialist_pool_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2440,6 +2773,56 @@ async def test_delete_specialist_pool_empty_call_async(): assert args[0] == specialist_pool_service.DeleteSpecialistPoolRequest() +@pytest.mark.asyncio +async def test_delete_specialist_pool_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_specialist_pool + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_specialist_pool + ] = mock_object + + request = {} + await client.delete_specialist_pool(request) + + # Establish that the 
underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_specialist_pool_async( transport: str = "grpc_asyncio", @@ -2677,6 +3060,9 @@ def test_update_specialist_pool_empty_call(): with mock.patch.object( type(client.transport.update_specialist_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2700,12 +3086,59 @@ def test_update_specialist_pool_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.update_specialist_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_specialist_pool(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.UpdateSpecialistPoolRequest() +def test_update_specialist_pool_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_specialist_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_specialist_pool + ] = mock_rpc + request = {} + client.update_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_specialist_pool_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2729,6 +3162,56 @@ async def test_update_specialist_pool_empty_call_async(): assert args[0] == specialist_pool_service.UpdateSpecialistPoolRequest() +@pytest.mark.asyncio +async def test_update_specialist_pool_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_specialist_pool + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_specialist_pool + ] = mock_object + + request = {} + await client.update_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_specialist_pool_async( transport: str = "grpc_asyncio", @@ -3050,6 +3533,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_specialist_pool_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_specialist_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_specialist_pool + ] = mock_rpc + + request = {} + client.create_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_specialist_pool_rest_required_fields( request_type=specialist_pool_service.CreateSpecialistPoolRequest, ): @@ -3338,6 +3866,46 @@ def test_get_specialist_pool_rest(request_type): assert response.specialist_worker_emails == ["specialist_worker_emails_value"] +def test_get_specialist_pool_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_specialist_pool in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_specialist_pool + ] = mock_rpc + + request = {} + client.get_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_specialist_pool_rest_required_fields( request_type=specialist_pool_service.GetSpecialistPoolRequest, ): @@ -3612,6 +4180,47 @@ def test_list_specialist_pools_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_specialist_pools_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_specialist_pools + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_specialist_pools + ] = mock_rpc + + request = {} + client.list_specialist_pools(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_specialist_pools(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_specialist_pools_rest_required_fields( request_type=specialist_pool_service.ListSpecialistPoolsRequest, ): @@ -3958,6 +4567,51 @@ def test_delete_specialist_pool_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_specialist_pool_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_specialist_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_specialist_pool + ] = mock_rpc + + request = {} + client.delete_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_specialist_pool_rest_required_fields( request_type=specialist_pool_service.DeleteSpecialistPoolRequest, ): @@ -4314,6 +4968,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_specialist_pool_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_specialist_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_specialist_pool + ] = mock_rpc + + request = {} + client.update_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_specialist_pool_rest_required_fields( request_type=specialist_pool_service.UpdateSpecialistPoolRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_tensorboard_service.py b/tests/unit/gapic/aiplatform_v1/test_tensorboard_service.py index f1fb2edfbf..2a251348bd 100644 --- a/tests/unit/gapic/aiplatform_v1/test_tensorboard_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_tensorboard_service.py @@ -1242,6 +1242,9 @@ def test_create_tensorboard_empty_call(): with mock.patch.object( type(client.transport.create_tensorboard), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_tensorboard() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1267,6 +1270,9 @@ def test_create_tensorboard_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_tensorboard), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_tensorboard(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1275,6 +1281,49 @@ def test_create_tensorboard_non_empty_request_with_auto_populated_field(): ) +def test_create_tensorboard_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_tensorboard in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_tensorboard + ] = mock_rpc + request = {} + client.create_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_tensorboard_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1298,6 +1347,56 @@ async def test_create_tensorboard_empty_call_async(): assert args[0] == tensorboard_service.CreateTensorboardRequest() +@pytest.mark.asyncio +async def test_create_tensorboard_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_tensorboard + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_tensorboard + ] = mock_object + + request = {} + await client.create_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_tensorboard_async( transport: str = "grpc_asyncio", @@ -1556,6 +1655,9 @@ def test_get_tensorboard_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_tensorboard), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_tensorboard() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1579,6 +1681,9 @@ def test_get_tensorboard_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_tensorboard), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_tensorboard(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1587,6 +1692,41 @@ def test_get_tensorboard_non_empty_request_with_auto_populated_field(): ) +def test_get_tensorboard_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_tensorboard in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_tensorboard] = mock_rpc + request = {} + client.get_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_tensorboard_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1616,6 +1756,52 @@ async def test_get_tensorboard_empty_call_async(): assert args[0] == tensorboard_service.GetTensorboardRequest() +@pytest.mark.asyncio +async def test_get_tensorboard_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_tensorboard + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_tensorboard + ] = mock_object + + request = {} + await client.get_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_tensorboard_async( transport: str = "grpc_asyncio", @@ -1858,6 +2044,9 @@ def test_update_tensorboard_empty_call(): with mock.patch.object( type(client.transport.update_tensorboard), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_tensorboard() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1881,12 +2070,58 @@ def test_update_tensorboard_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.update_tensorboard), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_tensorboard(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == tensorboard_service.UpdateTensorboardRequest() +def test_update_tensorboard_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_tensorboard in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.update_tensorboard + ] = mock_rpc + request = {} + client.update_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_tensorboard_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1910,6 +2145,56 @@ async def test_update_tensorboard_empty_call_async(): assert args[0] == tensorboard_service.UpdateTensorboardRequest() +@pytest.mark.asyncio +async def test_update_tensorboard_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_tensorboard + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_tensorboard + ] = mock_object + + request = {} + await client.update_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_tensorboard_async( transport: str = "grpc_asyncio", @@ -2160,6 +2445,9 @@ def test_list_tensorboards_empty_call(): with mock.patch.object( type(client.transport.list_tensorboards), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_tensorboards() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2188,6 +2476,9 @@ def test_list_tensorboards_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_tensorboards), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_tensorboards(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2199,6 +2490,43 @@ def test_list_tensorboards_non_empty_request_with_auto_populated_field(): ) +def test_list_tensorboards_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_tensorboards in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_tensorboards + ] = mock_rpc + request = {} + client.list_tensorboards(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_tensorboards(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_tensorboards_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2224,6 +2552,52 @@ async def test_list_tensorboards_empty_call_async(): assert args[0] == tensorboard_service.ListTensorboardsRequest() +@pytest.mark.asyncio +async def test_list_tensorboards_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_tensorboards + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_tensorboards + ] = mock_object + + request = {} + await client.list_tensorboards(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_tensorboards(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_tensorboards_async( transport: str = "grpc_asyncio", @@ -2662,6 +3036,9 @@ def test_delete_tensorboard_empty_call(): with mock.patch.object( type(client.transport.delete_tensorboard), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_tensorboard() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2687,6 +3064,9 @@ def test_delete_tensorboard_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_tensorboard), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_tensorboard(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2695,6 +3075,49 @@ def test_delete_tensorboard_non_empty_request_with_auto_populated_field(): ) +def test_delete_tensorboard_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_tensorboard in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_tensorboard + ] = mock_rpc + request = {} + client.delete_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_tensorboard_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2718,6 +3141,56 @@ async def test_delete_tensorboard_empty_call_async(): assert args[0] == tensorboard_service.DeleteTensorboardRequest() +@pytest.mark.asyncio +async def test_delete_tensorboard_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_tensorboard + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_tensorboard + ] = mock_object + + request = {} + await client.delete_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_tensorboard_async( transport: str = "grpc_asyncio", @@ -2955,6 +3428,9 @@ def test_read_tensorboard_usage_empty_call(): with mock.patch.object( type(client.transport.read_tensorboard_usage), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.read_tensorboard_usage() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2980,6 +3456,9 @@ def test_read_tensorboard_usage_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.read_tensorboard_usage), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.read_tensorboard_usage(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2988,6 +3467,46 @@ def test_read_tensorboard_usage_non_empty_request_with_auto_populated_field(): ) +def test_read_tensorboard_usage_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_tensorboard_usage + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_tensorboard_usage + ] = mock_rpc + request = {} + client.read_tensorboard_usage(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_tensorboard_usage(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_read_tensorboard_usage_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3011,6 +3530,52 @@ async def test_read_tensorboard_usage_empty_call_async(): assert args[0] == tensorboard_service.ReadTensorboardUsageRequest() +@pytest.mark.asyncio +async def test_read_tensorboard_usage_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.read_tensorboard_usage + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.read_tensorboard_usage + ] = mock_object + + request = {} + await client.read_tensorboard_usage(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.read_tensorboard_usage(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_read_tensorboard_usage_async( transport: str = "grpc_asyncio", @@ -3251,6 +3816,9 @@ def test_read_tensorboard_size_empty_call(): with mock.patch.object( type(client.transport.read_tensorboard_size), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.read_tensorboard_size() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3276,6 +3844,9 @@ def test_read_tensorboard_size_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.read_tensorboard_size), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.read_tensorboard_size(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3284,6 +3855,46 @@ def test_read_tensorboard_size_non_empty_request_with_auto_populated_field(): ) +def test_read_tensorboard_size_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_tensorboard_size + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.read_tensorboard_size + ] = mock_rpc + request = {} + client.read_tensorboard_size(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.read_tensorboard_size(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_read_tensorboard_size_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3309,6 +3920,52 @@ async def test_read_tensorboard_size_empty_call_async(): assert args[0] == tensorboard_service.ReadTensorboardSizeRequest() +@pytest.mark.asyncio +async def test_read_tensorboard_size_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.read_tensorboard_size + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.read_tensorboard_size + ] = mock_object + + request = {} + await client.read_tensorboard_size(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.read_tensorboard_size(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_read_tensorboard_size_async( transport: str = "grpc_asyncio", @@ -3560,6 +4217,9 @@ def test_create_tensorboard_experiment_empty_call(): with mock.patch.object( type(client.transport.create_tensorboard_experiment), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_tensorboard_experiment() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3586,6 +4246,9 @@ def test_create_tensorboard_experiment_non_empty_request_with_auto_populated_fie with mock.patch.object( type(client.transport.create_tensorboard_experiment), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_tensorboard_experiment(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3595,6 +4258,46 @@ def test_create_tensorboard_experiment_non_empty_request_with_auto_populated_fie ) +def test_create_tensorboard_experiment_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_tensorboard_experiment + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_tensorboard_experiment + ] = mock_rpc + request = {} + client.create_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_tensorboard_experiment_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3624,6 +4327,52 @@ async def test_create_tensorboard_experiment_empty_call_async(): assert args[0] == tensorboard_service.CreateTensorboardExperimentRequest() +@pytest.mark.asyncio +async def test_create_tensorboard_experiment_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_tensorboard_experiment + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_tensorboard_experiment + ] = mock_object + + request = {} + await client.create_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_tensorboard_experiment_async( transport: str = "grpc_asyncio", @@ -3911,6 +4660,9 @@ def test_get_tensorboard_experiment_empty_call(): with mock.patch.object( type(client.transport.get_tensorboard_experiment), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_tensorboard_experiment() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3936,6 +4688,9 @@ def test_get_tensorboard_experiment_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.get_tensorboard_experiment), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_tensorboard_experiment(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3944,6 +4699,46 @@ def test_get_tensorboard_experiment_non_empty_request_with_auto_populated_field( ) +def test_get_tensorboard_experiment_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_tensorboard_experiment + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_tensorboard_experiment + ] = mock_rpc + request = {} + client.get_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_tensorboard_experiment_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3973,6 +4768,52 @@ async def test_get_tensorboard_experiment_empty_call_async(): assert args[0] == tensorboard_service.GetTensorboardExperimentRequest() +@pytest.mark.asyncio +async def test_get_tensorboard_experiment_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_tensorboard_experiment + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_tensorboard_experiment + ] = mock_object + + request = {} + await client.get_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_tensorboard_experiment_async( transport: str = "grpc_asyncio", @@ -4232,6 +5073,9 @@ def test_update_tensorboard_experiment_empty_call(): with mock.patch.object( type(client.transport.update_tensorboard_experiment), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_tensorboard_experiment() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4255,12 +5099,55 @@ def test_update_tensorboard_experiment_non_empty_request_with_auto_populated_fie with mock.patch.object( type(client.transport.update_tensorboard_experiment), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_tensorboard_experiment(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == tensorboard_service.UpdateTensorboardExperimentRequest() +def test_update_tensorboard_experiment_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_tensorboard_experiment + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_tensorboard_experiment + ] = mock_rpc + request = {} + client.update_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_tensorboard_experiment_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4290,6 +5177,52 @@ async def test_update_tensorboard_experiment_empty_call_async(): assert args[0] == tensorboard_service.UpdateTensorboardExperimentRequest() +@pytest.mark.asyncio +async def test_update_tensorboard_experiment_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_tensorboard_experiment + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_tensorboard_experiment + ] = mock_object + + request = {} + await client.update_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_tensorboard_experiment_async( transport: str = "grpc_asyncio", @@ -4559,6 +5492,9 @@ def test_list_tensorboard_experiments_empty_call(): with mock.patch.object( type(client.transport.list_tensorboard_experiments), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_tensorboard_experiments() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4587,6 +5523,9 @@ def test_list_tensorboard_experiments_non_empty_request_with_auto_populated_fiel with mock.patch.object( type(client.transport.list_tensorboard_experiments), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_tensorboard_experiments(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4598,6 +5537,46 @@ def test_list_tensorboard_experiments_non_empty_request_with_auto_populated_fiel ) +def test_list_tensorboard_experiments_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_tensorboard_experiments + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_tensorboard_experiments + ] = mock_rpc + request = {} + client.list_tensorboard_experiments(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_tensorboard_experiments(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_tensorboard_experiments_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4623,6 +5602,52 @@ async def test_list_tensorboard_experiments_empty_call_async(): assert args[0] == tensorboard_service.ListTensorboardExperimentsRequest() +@pytest.mark.asyncio +async def test_list_tensorboard_experiments_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_tensorboard_experiments + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_tensorboard_experiments + ] = mock_object + + request = {} + await client.list_tensorboard_experiments(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_tensorboard_experiments(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_tensorboard_experiments_async( transport: str = "grpc_asyncio", @@ -5066,6 +6091,9 @@ def test_delete_tensorboard_experiment_empty_call(): with mock.patch.object( type(client.transport.delete_tensorboard_experiment), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_tensorboard_experiment() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5091,6 +6119,9 @@ def test_delete_tensorboard_experiment_non_empty_request_with_auto_populated_fie with mock.patch.object( type(client.transport.delete_tensorboard_experiment), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_tensorboard_experiment(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5099,6 +6130,50 @@ def test_delete_tensorboard_experiment_non_empty_request_with_auto_populated_fie ) +def test_delete_tensorboard_experiment_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_tensorboard_experiment + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_tensorboard_experiment + ] = mock_rpc + request = {} + client.delete_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_tensorboard_experiment_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5116,10 +6191,60 @@ async def test_delete_tensorboard_experiment_empty_call_async(): call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( operations_pb2.Operation(name="operations/spam") ) - response = await client.delete_tensorboard_experiment() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == tensorboard_service.DeleteTensorboardExperimentRequest() + response = await client.delete_tensorboard_experiment() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == tensorboard_service.DeleteTensorboardExperimentRequest() + + +@pytest.mark.asyncio +async def test_delete_tensorboard_experiment_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_tensorboard_experiment + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + 
+ mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_tensorboard_experiment + ] = mock_object + + request = {} + await client.delete_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 @pytest.mark.asyncio @@ -5368,6 +6493,9 @@ def test_create_tensorboard_run_empty_call(): with mock.patch.object( type(client.transport.create_tensorboard_run), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_tensorboard_run() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5394,6 +6522,9 @@ def test_create_tensorboard_run_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_tensorboard_run), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_tensorboard_run(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5403,6 +6534,46 @@ def test_create_tensorboard_run_non_empty_request_with_auto_populated_field(): ) +def test_create_tensorboard_run_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_tensorboard_run + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_tensorboard_run + ] = mock_rpc + request = {} + client.create_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_tensorboard_run_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5431,6 +6602,52 @@ async def test_create_tensorboard_run_empty_call_async(): assert args[0] == tensorboard_service.CreateTensorboardRunRequest() +@pytest.mark.asyncio +async def test_create_tensorboard_run_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_tensorboard_run + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_tensorboard_run + ] = mock_object + + request = {} + await client.create_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_tensorboard_run_async( transport: str = "grpc_asyncio", @@ -5697,6 +6914,9 @@ def test_batch_create_tensorboard_runs_empty_call(): with mock.patch.object( type(client.transport.batch_create_tensorboard_runs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_create_tensorboard_runs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5722,6 +6942,9 @@ def test_batch_create_tensorboard_runs_non_empty_request_with_auto_populated_fie with mock.patch.object( type(client.transport.batch_create_tensorboard_runs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.batch_create_tensorboard_runs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5730,6 +6953,46 @@ def test_batch_create_tensorboard_runs_non_empty_request_with_auto_populated_fie ) +def test_batch_create_tensorboard_runs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_create_tensorboard_runs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_create_tensorboard_runs + ] = mock_rpc + request = {} + client.batch_create_tensorboard_runs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.batch_create_tensorboard_runs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_batch_create_tensorboard_runs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5753,6 +7016,52 @@ async def test_batch_create_tensorboard_runs_empty_call_async(): assert args[0] == tensorboard_service.BatchCreateTensorboardRunsRequest() +@pytest.mark.asyncio +async def test_batch_create_tensorboard_runs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_create_tensorboard_runs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_create_tensorboard_runs + ] = mock_object + + request = {} + await client.batch_create_tensorboard_runs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.batch_create_tensorboard_runs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_batch_create_tensorboard_runs_async( transport: str = "grpc_asyncio", @@ -6021,6 +7330,9 @@ def test_get_tensorboard_run_empty_call(): with mock.patch.object( type(client.transport.get_tensorboard_run), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_tensorboard_run() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6046,6 +7358,9 @@ def test_get_tensorboard_run_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_tensorboard_run), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_tensorboard_run(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6054,6 +7369,45 @@ def test_get_tensorboard_run_non_empty_request_with_auto_populated_field(): ) +def test_get_tensorboard_run_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_tensorboard_run in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.get_tensorboard_run + ] = mock_rpc + request = {} + client.get_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_tensorboard_run_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6082,6 +7436,52 @@ async def test_get_tensorboard_run_empty_call_async(): assert args[0] == tensorboard_service.GetTensorboardRunRequest() +@pytest.mark.asyncio +async def test_get_tensorboard_run_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_tensorboard_run + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_tensorboard_run + ] = mock_object + + request = {} + await client.get_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_tensorboard_run_async( transport: str = "grpc_asyncio", @@ -6337,6 +7737,9 @@ def test_update_tensorboard_run_empty_call(): with mock.patch.object( type(client.transport.update_tensorboard_run), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_tensorboard_run() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6360,12 +7763,55 @@ def test_update_tensorboard_run_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.update_tensorboard_run), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_tensorboard_run(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == tensorboard_service.UpdateTensorboardRunRequest() +def test_update_tensorboard_run_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_tensorboard_run + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.update_tensorboard_run + ] = mock_rpc + request = {} + client.update_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_tensorboard_run_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6394,6 +7840,52 @@ async def test_update_tensorboard_run_empty_call_async(): assert args[0] == tensorboard_service.UpdateTensorboardRunRequest() +@pytest.mark.asyncio +async def test_update_tensorboard_run_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_tensorboard_run + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_tensorboard_run + ] = mock_object + + request = {} + await client.update_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_tensorboard_run_async( transport: str = "grpc_asyncio", @@ -6653,6 +8145,9 @@ def test_list_tensorboard_runs_empty_call(): with mock.patch.object( type(client.transport.list_tensorboard_runs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_tensorboard_runs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6681,6 +8176,9 @@ def test_list_tensorboard_runs_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_tensorboard_runs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_tensorboard_runs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6692,6 +8190,46 @@ def test_list_tensorboard_runs_non_empty_request_with_auto_populated_field(): ) +def test_list_tensorboard_runs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_tensorboard_runs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_tensorboard_runs + ] = mock_rpc + request = {} + client.list_tensorboard_runs(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_tensorboard_runs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_tensorboard_runs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6717,6 +8255,52 @@ async def test_list_tensorboard_runs_empty_call_async(): assert args[0] == tensorboard_service.ListTensorboardRunsRequest() +@pytest.mark.asyncio +async def test_list_tensorboard_runs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_tensorboard_runs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_tensorboard_runs + ] = mock_object + + request = {} + await client.list_tensorboard_runs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_tensorboard_runs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_tensorboard_runs_async( transport: str = "grpc_asyncio", @@ -7155,6 +8739,9 @@ def test_delete_tensorboard_run_empty_call(): with mock.patch.object( type(client.transport.delete_tensorboard_run), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_tensorboard_run() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7180,6 +8767,9 @@ def test_delete_tensorboard_run_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_tensorboard_run), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_tensorboard_run(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7188,6 +8778,50 @@ def test_delete_tensorboard_run_non_empty_request_with_auto_populated_field(): ) +def test_delete_tensorboard_run_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_tensorboard_run + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_tensorboard_run + ] = mock_rpc + request = {} + client.delete_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_tensorboard_run_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7211,6 +8845,56 @@ async def test_delete_tensorboard_run_empty_call_async(): assert args[0] == tensorboard_service.DeleteTensorboardRunRequest() +@pytest.mark.asyncio +async def test_delete_tensorboard_run_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_tensorboard_run + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_tensorboard_run + ] = mock_object + + request = {} + await client.delete_tensorboard_run(request) + + # Establish that the 
underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_tensorboard_run_async( transport: str = "grpc_asyncio", @@ -7452,6 +9136,9 @@ def test_batch_create_tensorboard_time_series_empty_call(): with mock.patch.object( type(client.transport.batch_create_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_create_tensorboard_time_series() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7477,6 +9164,9 @@ def test_batch_create_tensorboard_time_series_non_empty_request_with_auto_popula with mock.patch.object( type(client.transport.batch_create_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.batch_create_tensorboard_time_series(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7485,6 +9175,46 @@ def test_batch_create_tensorboard_time_series_non_empty_request_with_auto_popula ) +def test_batch_create_tensorboard_time_series_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_create_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_create_tensorboard_time_series + ] = mock_rpc + request = {} + client.batch_create_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.batch_create_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_batch_create_tensorboard_time_series_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7508,6 +9238,52 @@ async def test_batch_create_tensorboard_time_series_empty_call_async(): assert args[0] == tensorboard_service.BatchCreateTensorboardTimeSeriesRequest() +@pytest.mark.asyncio +async def test_batch_create_tensorboard_time_series_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_create_tensorboard_time_series + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_create_tensorboard_time_series + ] = mock_object + + request = {} + await client.batch_create_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.batch_create_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_batch_create_tensorboard_time_series_async( transport: str = "grpc_asyncio", @@ -7805,6 +9581,9 @@ def test_create_tensorboard_time_series_empty_call(): with mock.patch.object( type(client.transport.create_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_tensorboard_time_series() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7831,6 +9610,9 @@ def test_create_tensorboard_time_series_non_empty_request_with_auto_populated_fi with mock.patch.object( type(client.transport.create_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_tensorboard_time_series(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7840,6 +9622,46 @@ def test_create_tensorboard_time_series_non_empty_request_with_auto_populated_fi ) +def test_create_tensorboard_time_series_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_tensorboard_time_series + ] = mock_rpc + request = {} + client.create_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_tensorboard_time_series_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7871,6 +9693,52 @@ async def test_create_tensorboard_time_series_empty_call_async(): assert args[0] == tensorboard_service.CreateTensorboardTimeSeriesRequest() +@pytest.mark.asyncio +async def test_create_tensorboard_time_series_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_tensorboard_time_series + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_tensorboard_time_series + ] = mock_object + + request = {} + await client.create_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_tensorboard_time_series_async( transport: str = "grpc_asyncio", @@ -8162,6 +10030,9 @@ def test_get_tensorboard_time_series_empty_call(): with mock.patch.object( type(client.transport.get_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_tensorboard_time_series() call.assert_called() _, args, _ = call.mock_calls[0] @@ -8187,6 +10058,9 @@ def test_get_tensorboard_time_series_non_empty_request_with_auto_populated_field with mock.patch.object( type(client.transport.get_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_tensorboard_time_series(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -8195,6 +10069,46 @@ def test_get_tensorboard_time_series_non_empty_request_with_auto_populated_field ) +def test_get_tensorboard_time_series_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_tensorboard_time_series + ] = mock_rpc + request = {} + client.get_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_tensorboard_time_series_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -8226,6 +10140,52 @@ async def test_get_tensorboard_time_series_empty_call_async(): assert args[0] == tensorboard_service.GetTensorboardTimeSeriesRequest() +@pytest.mark.asyncio +async def test_get_tensorboard_time_series_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_tensorboard_time_series + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_tensorboard_time_series + ] = mock_object + + request = {} + await client.get_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_tensorboard_time_series_async( transport: str = "grpc_asyncio", @@ -8499,6 +10459,9 @@ def test_update_tensorboard_time_series_empty_call(): with mock.patch.object( type(client.transport.update_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_tensorboard_time_series() call.assert_called() _, args, _ = call.mock_calls[0] @@ -8522,12 +10485,55 @@ def test_update_tensorboard_time_series_non_empty_request_with_auto_populated_fi with mock.patch.object( type(client.transport.update_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_tensorboard_time_series(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == tensorboard_service.UpdateTensorboardTimeSeriesRequest() +def test_update_tensorboard_time_series_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_tensorboard_time_series + ] = mock_rpc + request = {} + client.update_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_tensorboard_time_series_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -8559,6 +10565,52 @@ async def test_update_tensorboard_time_series_empty_call_async(): assert args[0] == tensorboard_service.UpdateTensorboardTimeSeriesRequest() +@pytest.mark.asyncio +async def test_update_tensorboard_time_series_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_tensorboard_time_series + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_tensorboard_time_series + ] = mock_object + + request = {} + await client.update_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_tensorboard_time_series_async( transport: str = "grpc_asyncio", @@ -8835,6 +10887,9 @@ def test_list_tensorboard_time_series_empty_call(): with mock.patch.object( type(client.transport.list_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_tensorboard_time_series() call.assert_called() _, args, _ = call.mock_calls[0] @@ -8863,6 +10918,9 @@ def test_list_tensorboard_time_series_non_empty_request_with_auto_populated_fiel with mock.patch.object( type(client.transport.list_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_tensorboard_time_series(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -8874,6 +10932,46 @@ def test_list_tensorboard_time_series_non_empty_request_with_auto_populated_fiel ) +def test_list_tensorboard_time_series_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_tensorboard_time_series + ] = mock_rpc + request = {} + client.list_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_tensorboard_time_series_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -8899,6 +10997,52 @@ async def test_list_tensorboard_time_series_empty_call_async(): assert args[0] == tensorboard_service.ListTensorboardTimeSeriesRequest() +@pytest.mark.asyncio +async def test_list_tensorboard_time_series_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_tensorboard_time_series + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_tensorboard_time_series + ] = mock_object + + request = {} + await client.list_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_tensorboard_time_series_async( transport: str = "grpc_asyncio", @@ -9343,6 +11487,9 @@ def test_delete_tensorboard_time_series_empty_call(): with mock.patch.object( type(client.transport.delete_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_tensorboard_time_series() call.assert_called() _, args, _ = call.mock_calls[0] @@ -9368,6 +11515,9 @@ def test_delete_tensorboard_time_series_non_empty_request_with_auto_populated_fi with mock.patch.object( type(client.transport.delete_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_tensorboard_time_series(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -9376,6 +11526,50 @@ def test_delete_tensorboard_time_series_non_empty_request_with_auto_populated_fi ) +def test_delete_tensorboard_time_series_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_tensorboard_time_series + ] = mock_rpc + request = {} + client.delete_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_tensorboard_time_series_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -9399,6 +11593,56 @@ async def test_delete_tensorboard_time_series_empty_call_async(): assert args[0] == tensorboard_service.DeleteTensorboardTimeSeriesRequest() +@pytest.mark.asyncio +async def test_delete_tensorboard_time_series_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_tensorboard_time_series + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_tensorboard_time_series + ] = mock_object + + request = {} + await client.delete_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_tensorboard_time_series_async( transport: str = "grpc_asyncio", @@ -9640,6 +11884,9 @@ def test_batch_read_tensorboard_time_series_data_empty_call(): with mock.patch.object( type(client.transport.batch_read_tensorboard_time_series_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_read_tensorboard_time_series_data() call.assert_called() _, args, _ = call.mock_calls[0] @@ -9667,6 +11914,9 @@ def test_batch_read_tensorboard_time_series_data_non_empty_request_with_auto_pop with mock.patch.object( type(client.transport.batch_read_tensorboard_time_series_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.batch_read_tensorboard_time_series_data(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -9675,6 +11925,46 @@ def test_batch_read_tensorboard_time_series_data_non_empty_request_with_auto_pop ) +def test_batch_read_tensorboard_time_series_data_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_read_tensorboard_time_series_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_read_tensorboard_time_series_data + ] = mock_rpc + request = {} + client.batch_read_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.batch_read_tensorboard_time_series_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_batch_read_tensorboard_time_series_data_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -9700,6 +11990,52 @@ async def test_batch_read_tensorboard_time_series_data_empty_call_async(): ) +@pytest.mark.asyncio +async def test_batch_read_tensorboard_time_series_data_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_read_tensorboard_time_series_data + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_read_tensorboard_time_series_data + ] = mock_object + + request = {} + await client.batch_read_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.batch_read_tensorboard_time_series_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_batch_read_tensorboard_time_series_data_async( transport: str = "grpc_asyncio", @@ -9947,6 +12283,9 @@ def test_read_tensorboard_time_series_data_empty_call(): with mock.patch.object( type(client.transport.read_tensorboard_time_series_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.read_tensorboard_time_series_data() call.assert_called() _, args, _ = call.mock_calls[0] @@ -9973,6 +12312,9 @@ def test_read_tensorboard_time_series_data_non_empty_request_with_auto_populated with mock.patch.object( type(client.transport.read_tensorboard_time_series_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.read_tensorboard_time_series_data(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -9982,6 +12324,46 @@ def test_read_tensorboard_time_series_data_non_empty_request_with_auto_populated ) +def test_read_tensorboard_time_series_data_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_tensorboard_time_series_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_tensorboard_time_series_data + ] = mock_rpc + request = {} + client.read_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_tensorboard_time_series_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_read_tensorboard_time_series_data_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -10005,6 +12387,52 @@ async def test_read_tensorboard_time_series_data_empty_call_async(): assert args[0] == tensorboard_service.ReadTensorboardTimeSeriesDataRequest() +@pytest.mark.asyncio +async def test_read_tensorboard_time_series_data_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.read_tensorboard_time_series_data + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.read_tensorboard_time_series_data + ] = mock_object + + request = {} + await client.read_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.read_tensorboard_time_series_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_read_tensorboard_time_series_data_async( transport: str = "grpc_asyncio", @@ -10247,6 +12675,9 @@ def test_read_tensorboard_blob_data_empty_call(): with mock.patch.object( type(client.transport.read_tensorboard_blob_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.read_tensorboard_blob_data() call.assert_called() _, args, _ = call.mock_calls[0] @@ -10272,6 +12703,9 @@ def test_read_tensorboard_blob_data_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.read_tensorboard_blob_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.read_tensorboard_blob_data(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -10280,6 +12714,46 @@ def test_read_tensorboard_blob_data_non_empty_request_with_auto_populated_field( ) +def test_read_tensorboard_blob_data_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_tensorboard_blob_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_tensorboard_blob_data + ] = mock_rpc + request = {} + client.read_tensorboard_blob_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_tensorboard_blob_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_read_tensorboard_blob_data_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -10304,6 +12778,52 @@ async def test_read_tensorboard_blob_data_empty_call_async(): assert args[0] == tensorboard_service.ReadTensorboardBlobDataRequest() +@pytest.mark.asyncio +async def test_read_tensorboard_blob_data_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.read_tensorboard_blob_data + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.read_tensorboard_blob_data + ] = mock_object + + request = {} + await client.read_tensorboard_blob_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.read_tensorboard_blob_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_read_tensorboard_blob_data_async( transport: str = "grpc_asyncio", @@ -10550,6 +13070,9 @@ def test_write_tensorboard_experiment_data_empty_call(): with mock.patch.object( type(client.transport.write_tensorboard_experiment_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.write_tensorboard_experiment_data() call.assert_called() _, args, _ = call.mock_calls[0] @@ -10575,6 +13098,9 @@ def test_write_tensorboard_experiment_data_non_empty_request_with_auto_populated with mock.patch.object( type(client.transport.write_tensorboard_experiment_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.write_tensorboard_experiment_data(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -10583,6 +13109,46 @@ def test_write_tensorboard_experiment_data_non_empty_request_with_auto_populated ) +def test_write_tensorboard_experiment_data_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.write_tensorboard_experiment_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.write_tensorboard_experiment_data + ] = mock_rpc + request = {} + client.write_tensorboard_experiment_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.write_tensorboard_experiment_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_write_tensorboard_experiment_data_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -10606,6 +13172,52 @@ async def test_write_tensorboard_experiment_data_empty_call_async(): assert args[0] == tensorboard_service.WriteTensorboardExperimentDataRequest() +@pytest.mark.asyncio +async def test_write_tensorboard_experiment_data_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.write_tensorboard_experiment_data + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.write_tensorboard_experiment_data + ] = mock_object + + request = {} + await client.write_tensorboard_experiment_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.write_tensorboard_experiment_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_write_tensorboard_experiment_data_async( transport: str = "grpc_asyncio", @@ -10879,6 +13491,9 @@ def test_write_tensorboard_run_data_empty_call(): with mock.patch.object( type(client.transport.write_tensorboard_run_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.write_tensorboard_run_data() call.assert_called() _, args, _ = call.mock_calls[0] @@ -10904,6 +13519,9 @@ def test_write_tensorboard_run_data_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.write_tensorboard_run_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.write_tensorboard_run_data(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -10912,6 +13530,46 @@ def test_write_tensorboard_run_data_non_empty_request_with_auto_populated_field( ) +def test_write_tensorboard_run_data_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.write_tensorboard_run_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.write_tensorboard_run_data + ] = mock_rpc + request = {} + client.write_tensorboard_run_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.write_tensorboard_run_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_write_tensorboard_run_data_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -10935,6 +13593,52 @@ async def test_write_tensorboard_run_data_empty_call_async(): assert args[0] == tensorboard_service.WriteTensorboardRunDataRequest() +@pytest.mark.asyncio +async def test_write_tensorboard_run_data_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.write_tensorboard_run_data + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.write_tensorboard_run_data + ] = mock_object + + request = {} + await client.write_tensorboard_run_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.write_tensorboard_run_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_write_tensorboard_run_data_async( transport: str = "grpc_asyncio", @@ -11209,6 +13913,9 @@ def test_export_tensorboard_time_series_data_empty_call(): with mock.patch.object( type(client.transport.export_tensorboard_time_series_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.export_tensorboard_time_series_data() call.assert_called() _, args, _ = call.mock_calls[0] @@ -11237,6 +13944,9 @@ def test_export_tensorboard_time_series_data_non_empty_request_with_auto_populat with mock.patch.object( type(client.transport.export_tensorboard_time_series_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.export_tensorboard_time_series_data(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -11248,6 +13958,46 @@ def test_export_tensorboard_time_series_data_non_empty_request_with_auto_populat ) +def test_export_tensorboard_time_series_data_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.export_tensorboard_time_series_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.export_tensorboard_time_series_data + ] = mock_rpc + request = {} + client.export_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.export_tensorboard_time_series_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_export_tensorboard_time_series_data_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -11273,6 +14023,52 @@ async def test_export_tensorboard_time_series_data_empty_call_async(): assert args[0] == tensorboard_service.ExportTensorboardTimeSeriesDataRequest() +@pytest.mark.asyncio +async def test_export_tensorboard_time_series_data_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.export_tensorboard_time_series_data + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.export_tensorboard_time_series_data + ] = mock_object + + request = {} + await client.export_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.export_tensorboard_time_series_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_export_tensorboard_time_series_data_async( transport: str = "grpc_asyncio", @@ -11789,6 +14585,50 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_tensorboard_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_tensorboard in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_tensorboard + ] = mock_rpc + + request = {} + client.create_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_tensorboard_rest_required_fields( request_type=tensorboard_service.CreateTensorboardRequest, ): @@ -12076,6 +14916,42 @@ def test_get_tensorboard_rest(request_type): assert response.is_default is True +def test_get_tensorboard_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_tensorboard in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_tensorboard] = mock_rpc + + request = {} + client.get_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_tensorboard_rest_required_fields( request_type=tensorboard_service.GetTensorboardRequest, ): @@ -12424,6 +15300,50 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_tensorboard_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_tensorboard in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_tensorboard + ] = mock_rpc + + request = {} + client.update_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_tensorboard_rest_required_fields( request_type=tensorboard_service.UpdateTensorboardRequest, ): @@ -12704,6 +15624,44 @@ def test_list_tensorboards_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_tensorboards_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_tensorboards in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_tensorboards + ] = mock_rpc + + request = {} + client.list_tensorboards(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_tensorboards(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_tensorboards_rest_required_fields( request_type=tensorboard_service.ListTensorboardsRequest, ): @@ -13046,6 +16004,50 @@ def test_delete_tensorboard_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_tensorboard_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_tensorboard in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_tensorboard + ] = mock_rpc + + request = {} + client.delete_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_tensorboard_rest_required_fields( request_type=tensorboard_service.DeleteTensorboardRequest, ): @@ -13311,6 +16313,47 @@ def test_read_tensorboard_usage_rest(request_type): assert isinstance(response, tensorboard_service.ReadTensorboardUsageResponse) +def test_read_tensorboard_usage_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_tensorboard_usage + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_tensorboard_usage + ] = mock_rpc + + request = {} + client.read_tensorboard_usage(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_tensorboard_usage(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_read_tensorboard_usage_rest_required_fields( request_type=tensorboard_service.ReadTensorboardUsageRequest, ): @@ -13589,6 +16632,47 @@ def test_read_tensorboard_size_rest(request_type): assert response.storage_size_byte == 1826 +def test_read_tensorboard_size_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_tensorboard_size + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_tensorboard_size + ] = mock_rpc + + request = {} + client.read_tensorboard_size(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_tensorboard_size(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_read_tensorboard_size_rest_required_fields( request_type=tensorboard_service.ReadTensorboardSizeRequest, ): @@ -13953,6 +17037,47 @@ def get_message_fields(field): assert response.source == "source_value" +def test_create_tensorboard_experiment_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_tensorboard_experiment + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_tensorboard_experiment + ] = mock_rpc + + request = {} + client.create_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_tensorboard_experiment_rest_required_fields( request_type=tensorboard_service.CreateTensorboardExperimentRequest, ): @@ -14278,6 +17403,47 @@ def test_get_tensorboard_experiment_rest(request_type): assert response.source == "source_value" +def test_get_tensorboard_experiment_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_tensorboard_experiment + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_tensorboard_experiment + ] = mock_rpc + + request = {} + client.get_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_tensorboard_experiment_rest_required_fields( request_type=tensorboard_service.GetTensorboardExperimentRequest, ): @@ -14645,6 +17811,47 @@ def get_message_fields(field): assert response.source == "source_value" +def test_update_tensorboard_experiment_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_tensorboard_experiment + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_tensorboard_experiment + ] = mock_rpc + + request = {} + client.update_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_tensorboard_experiment_rest_required_fields( request_type=tensorboard_service.UpdateTensorboardExperimentRequest, ): @@ -14943,6 +18150,47 @@ def test_list_tensorboard_experiments_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_tensorboard_experiments_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_tensorboard_experiments + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_tensorboard_experiments + ] = mock_rpc + + request = {} + client.list_tensorboard_experiments(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_tensorboard_experiments(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_tensorboard_experiments_rest_required_fields( request_type=tensorboard_service.ListTensorboardExperimentsRequest, ): @@ -15300,6 +18548,51 @@ def test_delete_tensorboard_experiment_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_tensorboard_experiment_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_tensorboard_experiment + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_tensorboard_experiment + ] = mock_rpc + + request = {} + client.delete_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_tensorboard_experiment_rest_required_fields( request_type=tensorboard_service.DeleteTensorboardExperimentRequest, ): @@ -15659,6 +18952,47 @@ def get_message_fields(field): assert response.etag == "etag_value" +def test_create_tensorboard_run_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_tensorboard_run + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_tensorboard_run + ] = mock_rpc + + request = {} + client.create_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_tensorboard_run_rest_required_fields( request_type=tensorboard_service.CreateTensorboardRunRequest, ): @@ -15960,6 +19294,47 @@ def test_batch_create_tensorboard_runs_rest(request_type): assert isinstance(response, tensorboard_service.BatchCreateTensorboardRunsResponse) +def test_batch_create_tensorboard_runs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_create_tensorboard_runs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_create_tensorboard_runs + ] = mock_rpc + + request = {} + client.batch_create_tensorboard_runs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.batch_create_tensorboard_runs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_batch_create_tensorboard_runs_rest_required_fields( request_type=tensorboard_service.BatchCreateTensorboardRunsRequest, ): @@ -16265,6 +19640,46 @@ def test_get_tensorboard_run_rest(request_type): assert response.etag == "etag_value" +def test_get_tensorboard_run_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_tensorboard_run in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_tensorboard_run + ] = mock_rpc + + request = {} + client.get_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_tensorboard_run_rest_required_fields( request_type=tensorboard_service.GetTensorboardRunRequest, ): @@ -16624,6 +20039,47 @@ def get_message_fields(field): assert response.etag == "etag_value" +def test_update_tensorboard_run_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_tensorboard_run + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_tensorboard_run + ] = mock_rpc + + request = {} + client.update_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_tensorboard_run_rest_required_fields( request_type=tensorboard_service.UpdateTensorboardRunRequest, ): @@ -16910,6 +20366,47 @@ def test_list_tensorboard_runs_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_tensorboard_runs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_tensorboard_runs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_tensorboard_runs + ] = mock_rpc + + request = {} + client.list_tensorboard_runs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_tensorboard_runs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_tensorboard_runs_rest_required_fields( request_type=tensorboard_service.ListTensorboardRunsRequest, ): @@ -17262,6 +20759,51 @@ def test_delete_tensorboard_run_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_tensorboard_run_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_tensorboard_run + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_tensorboard_run + ] = mock_rpc + + request = {} + client.delete_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_tensorboard_run_rest_required_fields( request_type=tensorboard_service.DeleteTensorboardRunRequest, ): @@ -17534,6 +21076,47 @@ def test_batch_create_tensorboard_time_series_rest(request_type): ) +def test_batch_create_tensorboard_time_series_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_create_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_create_tensorboard_time_series + ] = mock_rpc + + request = {} + client.batch_create_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.batch_create_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_batch_create_tensorboard_time_series_rest_required_fields( request_type=tensorboard_service.BatchCreateTensorboardTimeSeriesRequest, ): @@ -17947,6 +21530,47 @@ def get_message_fields(field): assert response.plugin_data == b"plugin_data_blob" +def test_create_tensorboard_time_series_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_tensorboard_time_series + ] = mock_rpc + + request = {} + client.create_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_tensorboard_time_series_rest_required_fields( request_type=tensorboard_service.CreateTensorboardTimeSeriesRequest, ): @@ -18263,6 +21887,47 @@ def test_get_tensorboard_time_series_rest(request_type): assert response.plugin_data == b"plugin_data_blob" +def test_get_tensorboard_time_series_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_tensorboard_time_series + ] = mock_rpc + + request = {} + client.get_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_tensorboard_time_series_rest_required_fields( request_type=tensorboard_service.GetTensorboardTimeSeriesRequest, ): @@ -18647,6 +22312,47 @@ def get_message_fields(field): assert response.plugin_data == b"plugin_data_blob" +def test_update_tensorboard_time_series_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_tensorboard_time_series + ] = mock_rpc + + request = {} + client.update_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_tensorboard_time_series_rest_required_fields( request_type=tensorboard_service.UpdateTensorboardTimeSeriesRequest, ): @@ -18949,6 +22655,47 @@ def test_list_tensorboard_time_series_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_tensorboard_time_series_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_tensorboard_time_series + ] = mock_rpc + + request = {} + client.list_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_tensorboard_time_series_rest_required_fields( request_type=tensorboard_service.ListTensorboardTimeSeriesRequest, ): @@ -19309,6 +23056,51 @@ def test_delete_tensorboard_time_series_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_tensorboard_time_series_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_tensorboard_time_series + ] = mock_rpc + + request = {} + client.delete_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_tensorboard_time_series_rest_required_fields( request_type=tensorboard_service.DeleteTensorboardTimeSeriesRequest, ): @@ -19587,6 +23379,47 @@ def test_batch_read_tensorboard_time_series_data_rest(request_type): ) +def test_batch_read_tensorboard_time_series_data_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_read_tensorboard_time_series_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_read_tensorboard_time_series_data + ] = mock_rpc + + request = {} + client.batch_read_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.batch_read_tensorboard_time_series_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_batch_read_tensorboard_time_series_data_rest_required_fields( request_type=tensorboard_service.BatchReadTensorboardTimeSeriesDataRequest, ): @@ -19906,6 +23739,47 @@ def test_read_tensorboard_time_series_data_rest(request_type): ) +def test_read_tensorboard_time_series_data_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_tensorboard_time_series_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_tensorboard_time_series_data + ] = mock_rpc + + request = {} + client.read_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_tensorboard_time_series_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_read_tensorboard_time_series_data_rest_required_fields( request_type=tensorboard_service.ReadTensorboardTimeSeriesDataRequest, ): @@ -20213,6 +24087,47 @@ def test_read_tensorboard_blob_data_rest(request_type): assert isinstance(response, tensorboard_service.ReadTensorboardBlobDataResponse) +def test_read_tensorboard_blob_data_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_tensorboard_blob_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_tensorboard_blob_data + ] = mock_rpc + + request = {} + client.read_tensorboard_blob_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_tensorboard_blob_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_read_tensorboard_blob_data_rest_required_fields( request_type=tensorboard_service.ReadTensorboardBlobDataRequest, ): @@ -20503,6 +24418,47 @@ def test_write_tensorboard_experiment_data_rest(request_type): ) +def test_write_tensorboard_experiment_data_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.write_tensorboard_experiment_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.write_tensorboard_experiment_data + ] = mock_rpc + + request = {} + client.write_tensorboard_experiment_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.write_tensorboard_experiment_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_write_tensorboard_experiment_data_rest_required_fields( request_type=tensorboard_service.WriteTensorboardExperimentDataRequest, ): @@ -20809,6 +24765,47 @@ def test_write_tensorboard_run_data_rest(request_type): assert isinstance(response, tensorboard_service.WriteTensorboardRunDataResponse) +def test_write_tensorboard_run_data_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.write_tensorboard_run_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.write_tensorboard_run_data + ] = mock_rpc + + request = {} + client.write_tensorboard_run_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.write_tensorboard_run_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_write_tensorboard_run_data_rest_required_fields( request_type=tensorboard_service.WriteTensorboardRunDataRequest, ): @@ -21110,6 +25107,47 @@ def test_export_tensorboard_time_series_data_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_export_tensorboard_time_series_data_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.export_tensorboard_time_series_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.export_tensorboard_time_series_data + ] = mock_rpc + + request = {} + client.export_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.export_tensorboard_time_series_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_export_tensorboard_time_series_data_rest_required_fields( request_type=tensorboard_service.ExportTensorboardTimeSeriesDataRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1/test_vizier_service.py b/tests/unit/gapic/aiplatform_v1/test_vizier_service.py index fb1ca2f142..614f98c528 100644 --- a/tests/unit/gapic/aiplatform_v1/test_vizier_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_vizier_service.py @@ -1182,6 +1182,9 @@ def test_create_study_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_study), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_study() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1205,6 +1208,9 @@ def test_create_study_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_study), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_study(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1213,6 +1219,41 @@ def test_create_study_non_empty_request_with_auto_populated_field(): ) +def test_create_study_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_study in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_study] = mock_rpc + request = {} + client.create_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_study_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1239,6 +1280,52 @@ async def test_create_study_empty_call_async(): assert args[0] == vizier_service.CreateStudyRequest() +@pytest.mark.asyncio +async def test_create_study_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_study + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_study + ] = mock_object + + request = {} + await client.create_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_study_async( transport: str = "grpc_asyncio", request_type=vizier_service.CreateStudyRequest @@ -1485,6 +1572,9 @@ def test_get_study_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_study), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_study() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1508,6 +1598,9 @@ def test_get_study_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_study), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_study(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1516,6 +1609,41 @@ def test_get_study_non_empty_request_with_auto_populated_field(): ) +def test_get_study_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_study in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_study] = mock_rpc + request = {} + client.get_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_study_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1542,6 +1670,50 @@ async def test_get_study_empty_call_async(): assert args[0] == vizier_service.GetStudyRequest() +@pytest.mark.asyncio +async def test_get_study_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_study + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_study + ] = mock_object + + request = {} + await client.get_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_study_async( transport: str = "grpc_asyncio", request_type=vizier_service.GetStudyRequest @@ -1772,6 +1944,9 @@ def test_list_studies_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_studies() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1796,6 +1971,9 @@ def test_list_studies_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_studies(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1805,6 +1983,41 @@ def test_list_studies_non_empty_request_with_auto_populated_field(): ) +def test_list_studies_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_studies in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_studies] = mock_rpc + request = {} + client.list_studies(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_studies(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_studies_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1828,6 +2041,52 @@ async def test_list_studies_empty_call_async(): assert args[0] == vizier_service.ListStudiesRequest() +@pytest.mark.asyncio +async def test_list_studies_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_studies + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_studies + ] = mock_object + + request = {} + await client.list_studies(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_studies(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_studies_async( transport: str = "grpc_asyncio", request_type=vizier_service.ListStudiesRequest @@ -2243,6 +2502,9 @@ def test_delete_study_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_study), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_study() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2266,6 +2528,9 @@ def test_delete_study_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_study), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_study(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2274,6 +2539,41 @@ def test_delete_study_non_empty_request_with_auto_populated_field(): ) +def test_delete_study_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_study in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_study] = mock_rpc + request = {} + client.delete_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.delete_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_study_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2293,6 +2593,52 @@ async def test_delete_study_empty_call_async(): assert args[0] == vizier_service.DeleteStudyRequest() +@pytest.mark.asyncio +async def test_delete_study_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_study + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_study + ] = mock_object + + request = {} + await client.delete_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.delete_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_study_async( transport: str = "grpc_asyncio", request_type=vizier_service.DeleteStudyRequest @@ -2518,6 +2864,9 @@ def test_lookup_study_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.lookup_study() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2542,6 +2891,9 @@ def test_lookup_study_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.lookup_study(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2551,6 +2903,41 @@ def test_lookup_study_non_empty_request_with_auto_populated_field(): ) +def test_lookup_study_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.lookup_study in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.lookup_study] = mock_rpc + request = {} + client.lookup_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.lookup_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_lookup_study_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2577,6 +2964,52 @@ async def test_lookup_study_empty_call_async(): assert args[0] == vizier_service.LookupStudyRequest() +@pytest.mark.asyncio +async def test_lookup_study_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.lookup_study + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.lookup_study + ] = mock_object + + request = {} + await client.lookup_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.lookup_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_lookup_study_async( transport: str = "grpc_asyncio", request_type=vizier_service.LookupStudyRequest @@ -2804,6 +3237,9 @@ def test_suggest_trials_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.suggest_trials), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.suggest_trials() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2828,6 +3264,9 @@ def test_suggest_trials_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.suggest_trials), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.suggest_trials(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2837,6 +3276,45 @@ def test_suggest_trials_non_empty_request_with_auto_populated_field(): ) +def test_suggest_trials_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.suggest_trials in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.suggest_trials] = mock_rpc + request = {} + client.suggest_trials(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.suggest_trials(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_suggest_trials_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2858,6 +3336,56 @@ async def test_suggest_trials_empty_call_async(): assert args[0] == vizier_service.SuggestTrialsRequest() +@pytest.mark.asyncio +async def test_suggest_trials_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.suggest_trials + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.suggest_trials + ] = mock_object + + request = {} + await client.suggest_trials(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.suggest_trials(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_suggest_trials_async( transport: str = "grpc_asyncio", request_type=vizier_service.SuggestTrialsRequest @@ -3011,6 +3539,9 @@ def test_create_trial_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_trial), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_trial() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3034,6 +3565,9 @@ def test_create_trial_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_trial), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_trial(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3042,6 +3576,41 @@ def test_create_trial_non_empty_request_with_auto_populated_field(): ) +def test_create_trial_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_trial in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_trial] = mock_rpc + request = {} + client.create_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_trial_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3071,15 +3640,61 @@ async def test_create_trial_empty_call_async(): @pytest.mark.asyncio -async def test_create_trial_async( - transport: str = "grpc_asyncio", request_type=vizier_service.CreateTrialRequest +async def test_create_trial_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", ): - client = VizierServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) - # Everything is optional in proto3 as far as the runtime is concerned, + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_trial + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_trial + ] = mock_object + + request = {} + await client.create_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_create_trial_async( + transport: str = "grpc_asyncio", request_type=vizier_service.CreateTrialRequest +): + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. request = request_type() @@ -3324,6 +3939,9 @@ def test_get_trial_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_trial), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_trial() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3347,6 +3965,9 @@ def test_get_trial_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_trial), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_trial(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3355,6 +3976,41 @@ def test_get_trial_non_empty_request_with_auto_populated_field(): ) +def test_get_trial_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_trial in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_trial] = mock_rpc + request = {} + client.get_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_trial_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3383,6 +4039,50 @@ async def test_get_trial_empty_call_async(): assert args[0] == vizier_service.GetTrialRequest() +@pytest.mark.asyncio +async def test_get_trial_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_trial + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_trial + ] = mock_object + + request = {} + await client.get_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_trial_async( transport: str = "grpc_asyncio", request_type=vizier_service.GetTrialRequest @@ -3617,6 +4317,9 @@ def test_list_trials_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_trials() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3641,6 +4344,9 @@ def test_list_trials_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_trials(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3650,6 +4356,41 @@ def test_list_trials_non_empty_request_with_auto_populated_field(): ) +def test_list_trials_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_trials in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_trials] = mock_rpc + request = {} + client.list_trials(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_trials(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_trials_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3673,6 +4414,52 @@ async def test_list_trials_empty_call_async(): assert args[0] == vizier_service.ListTrialsRequest() +@pytest.mark.asyncio +async def test_list_trials_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_trials + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_trials + ] = mock_object + + request = {} + await client.list_trials(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_trials(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_trials_async( transport: str = "grpc_asyncio", request_type=vizier_service.ListTrialsRequest @@ -4105,6 +4892,9 @@ def test_add_trial_measurement_empty_call(): with mock.patch.object( type(client.transport.add_trial_measurement), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.add_trial_measurement() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4130,6 +4920,9 @@ def test_add_trial_measurement_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.add_trial_measurement), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.add_trial_measurement(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4138,6 +4931,46 @@ def test_add_trial_measurement_non_empty_request_with_auto_populated_field(): ) +def test_add_trial_measurement_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.add_trial_measurement + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.add_trial_measurement + ] = mock_rpc + request = {} + client.add_trial_measurement(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.add_trial_measurement(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_add_trial_measurement_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4168,6 +5001,52 @@ async def test_add_trial_measurement_empty_call_async(): assert args[0] == vizier_service.AddTrialMeasurementRequest() +@pytest.mark.asyncio +async def test_add_trial_measurement_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.add_trial_measurement + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.add_trial_measurement + ] = mock_object + + request = {} + await client.add_trial_measurement(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.add_trial_measurement(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_add_trial_measurement_async( transport: str = "grpc_asyncio", @@ -4339,6 +5218,9 @@ def test_complete_trial_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.complete_trial), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.complete_trial() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4363,6 +5245,9 @@ def test_complete_trial_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.complete_trial), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.complete_trial(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4372,6 +5257,41 @@ def test_complete_trial_non_empty_request_with_auto_populated_field(): ) +def test_complete_trial_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.complete_trial in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.complete_trial] = mock_rpc + request = {} + client.complete_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.complete_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_complete_trial_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4400,6 +5320,52 @@ async def test_complete_trial_empty_call_async(): assert args[0] == vizier_service.CompleteTrialRequest() +@pytest.mark.asyncio +async def test_complete_trial_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.complete_trial + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.complete_trial + ] = mock_object + + request = {} + await client.complete_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.complete_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_complete_trial_async( transport: str = "grpc_asyncio", request_type=vizier_service.CompleteTrialRequest @@ -4551,6 +5517,9 @@ def test_delete_trial_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_trial() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4574,6 +5543,9 @@ def test_delete_trial_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_trial(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4582,6 +5554,41 @@ def test_delete_trial_non_empty_request_with_auto_populated_field(): ) +def test_delete_trial_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_trial in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_trial] = mock_rpc + request = {} + client.delete_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.delete_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_trial_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4601,6 +5608,52 @@ async def test_delete_trial_empty_call_async(): assert args[0] == vizier_service.DeleteTrialRequest() +@pytest.mark.asyncio +async def test_delete_trial_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_trial + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_trial + ] = mock_object + + request = {} + await client.delete_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.delete_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_trial_async( transport: str = "grpc_asyncio", request_type=vizier_service.DeleteTrialRequest @@ -4821,6 +5874,9 @@ def test_check_trial_early_stopping_state_empty_call(): with mock.patch.object( type(client.transport.check_trial_early_stopping_state), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.check_trial_early_stopping_state() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4846,6 +5902,9 @@ def test_check_trial_early_stopping_state_non_empty_request_with_auto_populated_ with mock.patch.object( type(client.transport.check_trial_early_stopping_state), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.check_trial_early_stopping_state(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4854,6 +5913,50 @@ def test_check_trial_early_stopping_state_non_empty_request_with_auto_populated_ ) +def test_check_trial_early_stopping_state_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.check_trial_early_stopping_state + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.check_trial_early_stopping_state + ] = mock_rpc + request = {} + client.check_trial_early_stopping_state(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.check_trial_early_stopping_state(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_check_trial_early_stopping_state_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4877,6 +5980,56 @@ async def test_check_trial_early_stopping_state_empty_call_async(): assert args[0] == vizier_service.CheckTrialEarlyStoppingStateRequest() +@pytest.mark.asyncio +async def test_check_trial_early_stopping_state_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.check_trial_early_stopping_state + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.check_trial_early_stopping_state + ] = mock_object + + request = {} + await client.check_trial_early_stopping_state(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.check_trial_early_stopping_state(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_check_trial_early_stopping_state_async( transport: str = "grpc_asyncio", @@ -5037,6 +6190,9 @@ def test_stop_trial_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.stop_trial), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.stop_trial() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5060,6 +6216,9 @@ def test_stop_trial_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.stop_trial), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.stop_trial(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5068,6 +6227,41 @@ def test_stop_trial_non_empty_request_with_auto_populated_field(): ) +def test_stop_trial_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.stop_trial in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.stop_trial] = mock_rpc + request = {} + client.stop_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.stop_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_stop_trial_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5096,6 +6290,50 @@ async def test_stop_trial_empty_call_async(): assert args[0] == vizier_service.StopTrialRequest() +@pytest.mark.asyncio +async def test_stop_trial_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.stop_trial + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.stop_trial + ] = mock_object + + request = {} + await client.stop_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.stop_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_stop_trial_async( transport: str = "grpc_asyncio", request_type=vizier_service.StopTrialRequest @@ -5251,6 +6489,9 @@ def test_list_optimal_trials_empty_call(): with mock.patch.object( type(client.transport.list_optimal_trials), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_optimal_trials() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5276,6 +6517,9 @@ def test_list_optimal_trials_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_optimal_trials), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_optimal_trials(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5284,6 +6528,45 @@ def test_list_optimal_trials_non_empty_request_with_auto_populated_field(): ) +def test_list_optimal_trials_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_optimal_trials in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_optimal_trials + ] = mock_rpc + request = {} + client.list_optimal_trials(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_optimal_trials(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_optimal_trials_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5307,6 +6590,52 @@ async def test_list_optimal_trials_empty_call_async(): assert args[0] == vizier_service.ListOptimalTrialsRequest() +@pytest.mark.asyncio +async def test_list_optimal_trials_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_optimal_trials + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_optimal_trials + ] = mock_object + + request = {} + await client.list_optimal_trials(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_optimal_trials(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_optimal_trials_async( transport: str = "grpc_asyncio", @@ -5688,6 +7017,42 @@ def get_message_fields(field): assert response.inactive_reason == "inactive_reason_value" +def test_create_study_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_study in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_study] = mock_rpc + + request = {} + client.create_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_study_rest_required_fields( request_type=vizier_service.CreateStudyRequest, ): @@ -5969,6 +7334,42 @@ def test_get_study_rest(request_type): assert response.inactive_reason == "inactive_reason_value" +def test_get_study_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_study in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_study] = mock_rpc + + request = {} + client.get_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_study_rest_required_fields(request_type=vizier_service.GetStudyRequest): transport_class = transports.VizierServiceRestTransport @@ -6229,6 +7630,42 @@ def test_list_studies_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_studies_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_studies in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_studies] = mock_rpc + + request = {} + client.list_studies(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_studies(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_studies_rest_required_fields( request_type=vizier_service.ListStudiesRequest, ): @@ -6562,6 +7999,42 @@ def test_delete_study_rest(request_type): assert response is None +def test_delete_study_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_study in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_study] = mock_rpc + + request = {} + client.delete_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.delete_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_study_rest_required_fields( request_type=vizier_service.DeleteStudyRequest, ): @@ -6821,6 +8294,42 @@ def test_lookup_study_rest(request_type): assert response.inactive_reason == "inactive_reason_value" +def test_lookup_study_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.lookup_study in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.lookup_study] = mock_rpc + + request = {} + client.lookup_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.lookup_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_lookup_study_rest_required_fields( request_type=vizier_service.LookupStudyRequest, ): @@ -7094,6 +8603,46 @@ def test_suggest_trials_rest(request_type): assert response.operation.name == "operations/spam" +def test_suggest_trials_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.suggest_trials in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.suggest_trials] = mock_rpc + + request = {} + client.suggest_trials(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.suggest_trials(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_suggest_trials_rest_required_fields( request_type=vizier_service.SuggestTrialsRequest, ): @@ -7428,6 +8977,42 @@ def get_message_fields(field): assert response.custom_job == "custom_job_value" +def test_create_trial_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_trial in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_trial] = mock_rpc + + request = {} + client.create_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_trial_rest_required_fields( request_type=vizier_service.CreateTrialRequest, ): @@ -7718,6 +9303,42 @@ def test_get_trial_rest(request_type): assert response.custom_job == "custom_job_value" +def test_get_trial_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_trial in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_trial] = mock_rpc + + request = {} + client.get_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_trial_rest_required_fields(request_type=vizier_service.GetTrialRequest): transport_class = transports.VizierServiceRestTransport @@ -7983,6 +9604,42 @@ def test_list_trials_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_trials_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_trials in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_trials] = mock_rpc + + request = {} + client.list_trials(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_trials(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_trials_rest_required_fields( request_type=vizier_service.ListTrialsRequest, ): @@ -8336,6 +9993,47 @@ def test_add_trial_measurement_rest(request_type): assert response.custom_job == "custom_job_value" +def test_add_trial_measurement_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.add_trial_measurement + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.add_trial_measurement + ] = mock_rpc + + request = {} + client.add_trial_measurement(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.add_trial_measurement(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_add_trial_measurement_rest_required_fields( request_type=vizier_service.AddTrialMeasurementRequest, ): @@ -8567,6 +10265,42 @@ def test_complete_trial_rest(request_type): assert response.custom_job == "custom_job_value" +def test_complete_trial_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.complete_trial in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.complete_trial] = mock_rpc + + request = {} + client.complete_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.complete_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_complete_trial_rest_required_fields( request_type=vizier_service.CompleteTrialRequest, ): @@ -8775,6 +10509,42 @@ def test_delete_trial_rest(request_type): assert response is None +def test_delete_trial_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_trial in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_trial] = mock_rpc + + request = {} + client.delete_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.delete_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_trial_rest_required_fields( request_type=vizier_service.DeleteTrialRequest, ): @@ -9030,6 +10800,51 @@ def test_check_trial_early_stopping_state_rest(request_type): assert response.operation.name == "operations/spam" +def test_check_trial_early_stopping_state_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.check_trial_early_stopping_state + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.check_trial_early_stopping_state + ] = mock_rpc + + request = {} + client.check_trial_early_stopping_state(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.check_trial_early_stopping_state(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_check_trial_early_stopping_state_rest_required_fields( request_type=vizier_service.CheckTrialEarlyStoppingStateRequest, ): @@ -9257,6 +11072,42 @@ def test_stop_trial_rest(request_type): assert response.custom_job == "custom_job_value" +def test_stop_trial_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.stop_trial in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.stop_trial] = mock_rpc + + request = {} + client.stop_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.stop_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_stop_trial_rest_required_fields(request_type=vizier_service.StopTrialRequest): transport_class = transports.VizierServiceRestTransport @@ -9463,6 +11314,46 @@ def test_list_optimal_trials_rest(request_type): assert isinstance(response, vizier_service.ListOptimalTrialsResponse) +def test_list_optimal_trials_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_optimal_trials in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_optimal_trials + ] = mock_rpc + + request = {} + client.list_optimal_trials(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_optimal_trials(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_optimal_trials_rest_required_fields( request_type=vizier_service.ListOptimalTrialsRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py index 5a75e56dc4..73868d3690 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py @@ -1190,6 +1190,9 @@ def test_create_dataset_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_dataset() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1213,6 +1216,9 @@ def test_create_dataset_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_dataset(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1221,6 +1227,45 @@ def test_create_dataset_non_empty_request_with_auto_populated_field(): ) +def test_create_dataset_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_dataset in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_dataset] = mock_rpc + request = {} + client.create_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_dataset_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1242,6 +1287,56 @@ async def test_create_dataset_empty_call_async(): assert args[0] == dataset_service.CreateDatasetRequest() +@pytest.mark.asyncio +async def test_create_dataset_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_dataset + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_dataset + ] = mock_object + + request = {} + await client.create_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_dataset_async( transport: str = "grpc_asyncio", request_type=dataset_service.CreateDatasetRequest @@ -1489,6 +1584,9 @@ def test_get_dataset_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_dataset() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1512,6 +1610,9 @@ def test_get_dataset_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_dataset(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1520,6 +1621,41 @@ def test_get_dataset_non_empty_request_with_auto_populated_field(): ) +def test_get_dataset_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_dataset in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_dataset] = mock_rpc + request = {} + client.get_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_dataset_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1549,6 +1685,52 @@ async def test_get_dataset_empty_call_async(): assert args[0] == dataset_service.GetDatasetRequest() +@pytest.mark.asyncio +async def test_get_dataset_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_dataset + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_dataset + ] = mock_object + + request = {} + await client.get_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_dataset_async( transport: str = "grpc_asyncio", request_type=dataset_service.GetDatasetRequest @@ -1797,6 +1979,9 @@ def test_update_dataset_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_dataset() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1818,12 +2003,50 @@ def test_update_dataset_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_dataset(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == dataset_service.UpdateDatasetRequest() +def test_update_dataset_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_dataset in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_dataset] = mock_rpc + request = {} + client.update_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_dataset_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1853,6 +2076,52 @@ async def test_update_dataset_empty_call_async(): assert args[0] == dataset_service.UpdateDatasetRequest() +@pytest.mark.asyncio +async def test_update_dataset_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_dataset + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_dataset + ] = mock_object + + request = {} + await client.update_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_dataset_async( transport: str = "grpc_asyncio", request_type=dataset_service.UpdateDatasetRequest @@ -2099,6 +2368,9 @@ def test_list_datasets_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_datasets() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2125,6 +2397,9 @@ def test_list_datasets_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_datasets(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2136,6 +2411,41 @@ def test_list_datasets_non_empty_request_with_auto_populated_field(): ) +def test_list_datasets_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_datasets in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_datasets] = mock_rpc + request = {} + client.list_datasets(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_datasets(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_datasets_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2159,6 +2469,52 @@ async def test_list_datasets_empty_call_async(): assert args[0] == dataset_service.ListDatasetsRequest() +@pytest.mark.asyncio +async def test_list_datasets_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_datasets + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_datasets + ] = mock_object + + request = {} + await client.list_datasets(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_datasets(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_datasets_async( transport: str = "grpc_asyncio", request_type=dataset_service.ListDatasetsRequest @@ -2574,6 +2930,9 @@ def test_delete_dataset_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_dataset() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2597,6 +2956,9 @@ def test_delete_dataset_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_dataset(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2605,6 +2967,45 @@ def test_delete_dataset_non_empty_request_with_auto_populated_field(): ) +def test_delete_dataset_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_dataset in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_dataset] = mock_rpc + request = {} + client.delete_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_dataset_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2626,6 +3027,56 @@ async def test_delete_dataset_empty_call_async(): assert args[0] == dataset_service.DeleteDatasetRequest() +@pytest.mark.asyncio +async def test_delete_dataset_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_dataset + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_dataset + ] = mock_object + + request = {} + await client.delete_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_dataset_async( transport: str = "grpc_asyncio", request_type=dataset_service.DeleteDatasetRequest @@ -2848,6 +3299,9 @@ def test_import_data_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.import_data), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.import_data() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2871,6 +3325,9 @@ def test_import_data_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.import_data), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.import_data(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2879,6 +3336,45 @@ def test_import_data_non_empty_request_with_auto_populated_field(): ) +def test_import_data_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.import_data in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.import_data] = mock_rpc + request = {} + client.import_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.import_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_import_data_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2900,6 +3396,56 @@ async def test_import_data_empty_call_async(): assert args[0] == dataset_service.ImportDataRequest() +@pytest.mark.asyncio +async def test_import_data_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.import_data + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.import_data + ] = mock_object + + request = {} + await client.import_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.import_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_import_data_async( transport: str = "grpc_asyncio", request_type=dataset_service.ImportDataRequest @@ -3144,6 +3690,9 @@ def test_export_data_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.export_data), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.export_data() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3167,6 +3716,9 @@ def test_export_data_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.export_data), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.export_data(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3175,6 +3727,45 @@ def test_export_data_non_empty_request_with_auto_populated_field(): ) +def test_export_data_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.export_data in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.export_data] = mock_rpc + request = {} + client.export_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.export_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_export_data_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3197,10 +3788,60 @@ async def test_export_data_empty_call_async(): @pytest.mark.asyncio -async def test_export_data_async( - transport: str = "grpc_asyncio", request_type=dataset_service.ExportDataRequest +async def test_export_data_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", ): - client = DatasetServiceAsyncClient( + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.export_data + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.export_data + ] = mock_object + + request = {} + await client.export_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.export_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_export_data_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ExportDataRequest +): + client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3456,6 +4097,9 @@ def test_create_dataset_version_empty_call(): with mock.patch.object( type(client.transport.create_dataset_version), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_dataset_version() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3481,6 +4125,9 @@ def test_create_dataset_version_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_dataset_version), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_dataset_version(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3489,6 +4136,50 @@ def test_create_dataset_version_non_empty_request_with_auto_populated_field(): ) +def test_create_dataset_version_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_dataset_version + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_dataset_version + ] = mock_rpc + request = {} + client.create_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_dataset_version_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3512,6 +4203,56 @@ async def test_create_dataset_version_empty_call_async(): assert args[0] == dataset_service.CreateDatasetVersionRequest() +@pytest.mark.asyncio +async def test_create_dataset_version_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_dataset_version + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_dataset_version + ] = mock_object + + request = {} + await client.create_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_dataset_version_async( transport: str = "grpc_asyncio", @@ -3715,11 +4456,11 @@ async def test_create_dataset_version_flattened_error_async(): @pytest.mark.parametrize( "request_type", [ - dataset_service.DeleteDatasetVersionRequest, + dataset_service.UpdateDatasetVersionRequest, dict, ], ) -def test_delete_dataset_version(request_type, transport: str = "grpc"): +def test_update_dataset_version(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -3731,23 +4472,32 @@ def test_delete_dataset_version(request_type, transport: str = "grpc"): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_dataset_version), "__call__" + type(client.transport.update_dataset_version), "__call__" ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - response = client.delete_dataset_version(request) + call.return_value = gca_dataset_version.DatasetVersion( + name="name_value", + etag="etag_value", + big_query_dataset_name="big_query_dataset_name_value", + display_name="display_name_value", + ) + response = client.update_dataset_version(request) # Establish that the underlying gRPC stub method was called. 
assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - request = dataset_service.DeleteDatasetVersionRequest() + request = dataset_service.UpdateDatasetVersionRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) + assert isinstance(response, gca_dataset_version.DatasetVersion) + assert response.name == "name_value" + assert response.etag == "etag_value" + assert response.big_query_dataset_name == "big_query_dataset_name_value" + assert response.display_name == "display_name_value" -def test_delete_dataset_version_empty_call(): +def test_update_dataset_version_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( @@ -3757,15 +4507,18 @@ def test_delete_dataset_version_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_dataset_version), "__call__" + type(client.transport.update_dataset_version), "__call__" ) as call: - client.delete_dataset_version() + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.update_dataset_version() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.DeleteDatasetVersionRequest() + assert args[0] == dataset_service.UpdateDatasetVersionRequest() -def test_delete_dataset_version_non_empty_request_with_auto_populated_field(): +def test_update_dataset_version_non_empty_request_with_auto_populated_field(): # This test is a coverage failsafe to make sure that UUID4 fields are # automatically populated, according to AIP-4235, with non-empty requests. 
client = DatasetServiceClient( @@ -3776,24 +4529,63 @@ def test_delete_dataset_version_non_empty_request_with_auto_populated_field(): # Populate all string fields in the request which are not UUID4 # since we want to check that UUID4 are populated automatically # if they meet the requirements of AIP 4235. - request = dataset_service.DeleteDatasetVersionRequest( - name="name_value", - ) + request = dataset_service.UpdateDatasetVersionRequest() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_dataset_version), "__call__" + type(client.transport.update_dataset_version), "__call__" ) as call: - client.delete_dataset_version(request=request) + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.update_dataset_version(request=request) call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.DeleteDatasetVersionRequest( - name="name_value", + assert args[0] == dataset_service.UpdateDatasetVersionRequest() + + +def test_update_dataset_version_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_dataset_version + in client._transport._wrapped_methods ) + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.update_dataset_version + ] = mock_rpc + request = {} + client.update_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + @pytest.mark.asyncio -async def test_delete_dataset_version_empty_call_async(): +async def test_update_dataset_version_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceAsyncClient( @@ -3803,22 +4595,73 @@ async def test_delete_dataset_version_empty_call_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_dataset_version), "__call__" + type(client.transport.update_dataset_version), "__call__" ) as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + gca_dataset_version.DatasetVersion( + name="name_value", + etag="etag_value", + big_query_dataset_name="big_query_dataset_name_value", + display_name="display_name_value", + ) ) - response = await client.delete_dataset_version() + response = await client.update_dataset_version() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.DeleteDatasetVersionRequest() + assert args[0] == dataset_service.UpdateDatasetVersionRequest() @pytest.mark.asyncio -async def test_delete_dataset_version_async( +async def test_update_dataset_version_async_use_cached_wrapped_rpc( transport: str = "grpc_asyncio", - request_type=dataset_service.DeleteDatasetVersionRequest, +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_dataset_version + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_dataset_version + ] = mock_object + + request = {} + await client.update_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_update_dataset_version_async( + transport: str = "grpc_asyncio", + request_type=dataset_service.UpdateDatasetVersionRequest, ): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), @@ -3831,46 +4674,55 @@ async def test_delete_dataset_version_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_dataset_version), "__call__" + type(client.transport.update_dataset_version), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + gca_dataset_version.DatasetVersion( + name="name_value", + etag="etag_value", + big_query_dataset_name="big_query_dataset_name_value", + display_name="display_name_value", + ) ) - response = await client.delete_dataset_version(request) + response = await client.update_dataset_version(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - request = dataset_service.DeleteDatasetVersionRequest() + request = dataset_service.UpdateDatasetVersionRequest() assert args[0] == request # Establish that the response is the type that we expect. 
- assert isinstance(response, future.Future) + assert isinstance(response, gca_dataset_version.DatasetVersion) + assert response.name == "name_value" + assert response.etag == "etag_value" + assert response.big_query_dataset_name == "big_query_dataset_name_value" + assert response.display_name == "display_name_value" @pytest.mark.asyncio -async def test_delete_dataset_version_async_from_dict(): - await test_delete_dataset_version_async(request_type=dict) +async def test_update_dataset_version_async_from_dict(): + await test_update_dataset_version_async(request_type=dict) -def test_delete_dataset_version_field_headers(): +def test_update_dataset_version_field_headers(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = dataset_service.DeleteDatasetVersionRequest() + request = dataset_service.UpdateDatasetVersionRequest() - request.name = "name_value" + request.dataset_version.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_dataset_version), "__call__" + type(client.transport.update_dataset_version), "__call__" ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") - client.delete_dataset_version(request) + call.return_value = gca_dataset_version.DatasetVersion() + client.update_dataset_version(request) # Establish that the underlying gRPC stub method was called. 
assert len(call.mock_calls) == 1 @@ -3881,30 +4733,30 @@ def test_delete_dataset_version_field_headers(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "name=name_value", + "dataset_version.name=name_value", ) in kw["metadata"] @pytest.mark.asyncio -async def test_delete_dataset_version_field_headers_async(): +async def test_update_dataset_version_field_headers_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = dataset_service.DeleteDatasetVersionRequest() + request = dataset_service.UpdateDatasetVersionRequest() - request.name = "name_value" + request.dataset_version.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_dataset_version), "__call__" + type(client.transport.update_dataset_version), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") + gca_dataset_version.DatasetVersion() ) - await client.delete_dataset_version(request) + await client.update_dataset_version(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -3915,37 +4767,41 @@ async def test_delete_dataset_version_field_headers_async(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "name=name_value", + "dataset_version.name=name_value", ) in kw["metadata"] -def test_delete_dataset_version_flattened(): +def test_update_dataset_version_flattened(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.delete_dataset_version), "__call__" + type(client.transport.update_dataset_version), "__call__" ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = gca_dataset_version.DatasetVersion() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_dataset_version( - name="name_value", + client.update_dataset_version( + dataset_version=gca_dataset_version.DatasetVersion(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - arg = args[0].name - mock_val = "name_value" + arg = args[0].dataset_version + mock_val = gca_dataset_version.DatasetVersion(name="name_value") + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) assert arg == mock_val -def test_delete_dataset_version_flattened_error(): +def test_update_dataset_version_flattened_error(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -3953,45 +4809,50 @@ def test_delete_dataset_version_flattened_error(): # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - client.delete_dataset_version( - dataset_service.DeleteDatasetVersionRequest(), - name="name_value", + client.update_dataset_version( + dataset_service.UpdateDatasetVersionRequest(), + dataset_version=gca_dataset_version.DatasetVersion(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio -async def test_delete_dataset_version_flattened_async(): +async def test_update_dataset_version_flattened_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_dataset_version), "__call__" + type(client.transport.update_dataset_version), "__call__" ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = gca_dataset_version.DatasetVersion() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + gca_dataset_version.DatasetVersion() ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_dataset_version( - name="name_value", + response = await client.update_dataset_version( + dataset_version=gca_dataset_version.DatasetVersion(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - arg = args[0].name - mock_val = "name_value" + arg = args[0].dataset_version + mock_val = gca_dataset_version.DatasetVersion(name="name_value") + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) assert arg == mock_val @pytest.mark.asyncio -async def test_delete_dataset_version_flattened_error_async(): +async def test_update_dataset_version_flattened_error_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -3999,20 +4860,21 @@ async def test_delete_dataset_version_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.delete_dataset_version( - dataset_service.DeleteDatasetVersionRequest(), - name="name_value", + await client.update_dataset_version( + dataset_service.UpdateDatasetVersionRequest(), + dataset_version=gca_dataset_version.DatasetVersion(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) @pytest.mark.parametrize( "request_type", [ - dataset_service.GetDatasetVersionRequest, + dataset_service.DeleteDatasetVersionRequest, dict, ], ) -def test_get_dataset_version(request_type, transport: str = "grpc"): +def test_delete_dataset_version(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -4024,32 +4886,23 @@ def test_get_dataset_version(request_type, transport: str = "grpc"): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_dataset_version), "__call__" + type(client.transport.delete_dataset_version), "__call__" ) as call: # Designate an appropriate return value for the call. 
- call.return_value = dataset_version.DatasetVersion( - name="name_value", - etag="etag_value", - big_query_dataset_name="big_query_dataset_name_value", - display_name="display_name_value", - ) - response = client.get_dataset_version(request) + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.delete_dataset_version(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - request = dataset_service.GetDatasetVersionRequest() + request = dataset_service.DeleteDatasetVersionRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, dataset_version.DatasetVersion) - assert response.name == "name_value" - assert response.etag == "etag_value" - assert response.big_query_dataset_name == "big_query_dataset_name_value" - assert response.display_name == "display_name_value" + assert isinstance(response, future.Future) -def test_get_dataset_version_empty_call(): +def test_delete_dataset_version_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( @@ -4059,15 +4912,18 @@ def test_get_dataset_version_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_dataset_version), "__call__" + type(client.transport.delete_dataset_version), "__call__" ) as call: - client.get_dataset_version() + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client.delete_dataset_version() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.GetDatasetVersionRequest() + assert args[0] == dataset_service.DeleteDatasetVersionRequest() -def test_get_dataset_version_non_empty_request_with_auto_populated_field(): +def test_delete_dataset_version_non_empty_request_with_auto_populated_field(): # This test is a coverage failsafe to make sure that UUID4 fields are # automatically populated, according to AIP-4235, with non-empty requests. client = DatasetServiceClient( @@ -4078,24 +4934,71 @@ def test_get_dataset_version_non_empty_request_with_auto_populated_field(): # Populate all string fields in the request which are not UUID4 # since we want to check that UUID4 are populated automatically # if they meet the requirements of AIP 4235. - request = dataset_service.GetDatasetVersionRequest( + request = dataset_service.DeleteDatasetVersionRequest( name="name_value", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_dataset_version), "__call__" + type(client.transport.delete_dataset_version), "__call__" ) as call: - client.get_dataset_version(request=request) + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client.delete_dataset_version(request=request) call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.GetDatasetVersionRequest( + assert args[0] == dataset_service.DeleteDatasetVersionRequest( name="name_value", ) -@pytest.mark.asyncio -async def test_get_dataset_version_empty_call_async(): +def test_delete_dataset_version_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_dataset_version + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_dataset_version + ] = mock_rpc + request = {} + client.delete_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_dataset_version_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
client = DatasetServiceAsyncClient( @@ -4105,27 +5008,72 @@ async def test_get_dataset_version_empty_call_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_dataset_version), "__call__" + type(client.transport.delete_dataset_version), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_version.DatasetVersion( - name="name_value", - etag="etag_value", - big_query_dataset_name="big_query_dataset_name_value", - display_name="display_name_value", - ) + operations_pb2.Operation(name="operations/spam") ) - response = await client.get_dataset_version() + response = await client.delete_dataset_version() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.GetDatasetVersionRequest() + assert args[0] == dataset_service.DeleteDatasetVersionRequest() @pytest.mark.asyncio -async def test_get_dataset_version_async( +async def test_delete_dataset_version_async_use_cached_wrapped_rpc( transport: str = "grpc_asyncio", - request_type=dataset_service.GetDatasetVersionRequest, +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_dataset_version + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ 
+ client._client._transport.delete_dataset_version + ] = mock_object + + request = {} + await client.delete_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_dataset_version_async( + transport: str = "grpc_asyncio", + request_type=dataset_service.DeleteDatasetVersionRequest, ): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), @@ -4138,55 +5086,46 @@ async def test_get_dataset_version_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_dataset_version), "__call__" + type(client.transport.delete_dataset_version), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_version.DatasetVersion( - name="name_value", - etag="etag_value", - big_query_dataset_name="big_query_dataset_name_value", - display_name="display_name_value", - ) + operations_pb2.Operation(name="operations/spam") ) - response = await client.get_dataset_version(request) + response = await client.delete_dataset_version(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - request = dataset_service.GetDatasetVersionRequest() + request = dataset_service.DeleteDatasetVersionRequest() assert args[0] == request # Establish that the response is the type that we expect. 
- assert isinstance(response, dataset_version.DatasetVersion) - assert response.name == "name_value" - assert response.etag == "etag_value" - assert response.big_query_dataset_name == "big_query_dataset_name_value" - assert response.display_name == "display_name_value" + assert isinstance(response, future.Future) @pytest.mark.asyncio -async def test_get_dataset_version_async_from_dict(): - await test_get_dataset_version_async(request_type=dict) +async def test_delete_dataset_version_async_from_dict(): + await test_delete_dataset_version_async(request_type=dict) -def test_get_dataset_version_field_headers(): +def test_delete_dataset_version_field_headers(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = dataset_service.GetDatasetVersionRequest() + request = dataset_service.DeleteDatasetVersionRequest() request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_dataset_version), "__call__" + type(client.transport.delete_dataset_version), "__call__" ) as call: - call.return_value = dataset_version.DatasetVersion() - client.get_dataset_version(request) + call.return_value = operations_pb2.Operation(name="operations/op") + client.delete_dataset_version(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -4202,25 +5141,25 @@ def test_get_dataset_version_field_headers(): @pytest.mark.asyncio -async def test_get_dataset_version_field_headers_async(): +async def test_delete_dataset_version_field_headers_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. 
- request = dataset_service.GetDatasetVersionRequest() + request = dataset_service.DeleteDatasetVersionRequest() request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_dataset_version), "__call__" + type(client.transport.delete_dataset_version), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_version.DatasetVersion() + operations_pb2.Operation(name="operations/op") ) - await client.get_dataset_version(request) + await client.delete_dataset_version(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -4235,20 +5174,20 @@ async def test_get_dataset_version_field_headers_async(): ) in kw["metadata"] -def test_get_dataset_version_flattened(): +def test_delete_dataset_version_flattened(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_dataset_version), "__call__" + type(client.transport.delete_dataset_version), "__call__" ) as call: # Designate an appropriate return value for the call. - call.return_value = dataset_version.DatasetVersion() + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_dataset_version( + client.delete_dataset_version( name="name_value", ) @@ -4261,7 +5200,7 @@ def test_get_dataset_version_flattened(): assert arg == mock_val -def test_get_dataset_version_flattened_error(): +def test_delete_dataset_version_flattened_error(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -4269,31 +5208,31 @@ def test_get_dataset_version_flattened_error(): # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - client.get_dataset_version( - dataset_service.GetDatasetVersionRequest(), + client.delete_dataset_version( + dataset_service.DeleteDatasetVersionRequest(), name="name_value", ) @pytest.mark.asyncio -async def test_get_dataset_version_flattened_async(): +async def test_delete_dataset_version_flattened_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_dataset_version), "__call__" + type(client.transport.delete_dataset_version), "__call__" ) as call: # Designate an appropriate return value for the call. - call.return_value = dataset_version.DatasetVersion() + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_version.DatasetVersion() + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_dataset_version( + response = await client.delete_dataset_version( name="name_value", ) @@ -4307,7 +5246,7 @@ async def test_get_dataset_version_flattened_async(): @pytest.mark.asyncio -async def test_get_dataset_version_flattened_error_async(): +async def test_delete_dataset_version_flattened_error_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -4315,8 +5254,8 @@ async def test_get_dataset_version_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - await client.get_dataset_version( - dataset_service.GetDatasetVersionRequest(), + await client.delete_dataset_version( + dataset_service.DeleteDatasetVersionRequest(), name="name_value", ) @@ -4324,11 +5263,11 @@ async def test_get_dataset_version_flattened_error_async(): @pytest.mark.parametrize( "request_type", [ - dataset_service.ListDatasetVersionsRequest, + dataset_service.GetDatasetVersionRequest, dict, ], ) -def test_list_dataset_versions(request_type, transport: str = "grpc"): +def test_get_dataset_version(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -4340,26 +5279,32 @@ def test_list_dataset_versions(request_type, transport: str = "grpc"): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_dataset_versions), "__call__" + type(client.transport.get_dataset_version), "__call__" ) as call: # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListDatasetVersionsResponse( - next_page_token="next_page_token_value", + call.return_value = dataset_version.DatasetVersion( + name="name_value", + etag="etag_value", + big_query_dataset_name="big_query_dataset_name_value", + display_name="display_name_value", ) - response = client.list_dataset_versions(request) + response = client.get_dataset_version(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - request = dataset_service.ListDatasetVersionsRequest() + request = dataset_service.GetDatasetVersionRequest() assert args[0] == request # Establish that the response is the type that we expect. 
- assert isinstance(response, pagers.ListDatasetVersionsPager) - assert response.next_page_token == "next_page_token_value" + assert isinstance(response, dataset_version.DatasetVersion) + assert response.name == "name_value" + assert response.etag == "etag_value" + assert response.big_query_dataset_name == "big_query_dataset_name_value" + assert response.display_name == "display_name_value" -def test_list_dataset_versions_empty_call(): +def test_get_dataset_version_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( @@ -4369,15 +5314,18 @@ def test_list_dataset_versions_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_dataset_versions), "__call__" + type(client.transport.get_dataset_version), "__call__" ) as call: - client.list_dataset_versions() + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.get_dataset_version() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.ListDatasetVersionsRequest() + assert args[0] == dataset_service.GetDatasetVersionRequest() -def test_list_dataset_versions_non_empty_request_with_auto_populated_field(): +def test_get_dataset_version_non_empty_request_with_auto_populated_field(): # This test is a coverage failsafe to make sure that UUID4 fields are # automatically populated, according to AIP-4235, with non-empty requests. client = DatasetServiceClient( @@ -4388,30 +5336,66 @@ def test_list_dataset_versions_non_empty_request_with_auto_populated_field(): # Populate all string fields in the request which are not UUID4 # since we want to check that UUID4 are populated automatically # if they meet the requirements of AIP 4235. 
- request = dataset_service.ListDatasetVersionsRequest( - parent="parent_value", - filter="filter_value", - page_token="page_token_value", - order_by="order_by_value", + request = dataset_service.GetDatasetVersionRequest( + name="name_value", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_dataset_versions), "__call__" + type(client.transport.get_dataset_version), "__call__" ) as call: - client.list_dataset_versions(request=request) + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.get_dataset_version(request=request) call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.ListDatasetVersionsRequest( - parent="parent_value", - filter="filter_value", - page_token="page_token_value", - order_by="order_by_value", + assert args[0] == dataset_service.GetDatasetVersionRequest( + name="name_value", + ) + + +def test_get_dataset_version_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", ) + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_dataset_version in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_dataset_version + ] = mock_rpc + request = {} + client.get_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + @pytest.mark.asyncio -async def test_list_dataset_versions_empty_call_async(): +async def test_get_dataset_version_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceAsyncClient( @@ -4421,24 +5405,73 @@ async def test_list_dataset_versions_empty_call_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_dataset_versions), "__call__" + type(client.transport.get_dataset_version), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDatasetVersionsResponse( - next_page_token="next_page_token_value", + dataset_version.DatasetVersion( + name="name_value", + etag="etag_value", + big_query_dataset_name="big_query_dataset_name_value", + display_name="display_name_value", ) ) - response = await client.list_dataset_versions() + response = await client.get_dataset_version() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.ListDatasetVersionsRequest() + assert args[0] == dataset_service.GetDatasetVersionRequest() @pytest.mark.asyncio -async def test_list_dataset_versions_async( +async def test_get_dataset_version_async_use_cached_wrapped_rpc( transport: str = "grpc_asyncio", - request_type=dataset_service.ListDatasetVersionsRequest, +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), 
+ transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_dataset_version + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_dataset_version + ] = mock_object + + request = {} + await client.get_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.get_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_dataset_version_async( + transport: str = "grpc_asyncio", + request_type=dataset_service.GetDatasetVersionRequest, ): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), @@ -4451,49 +5484,55 @@ async def test_list_dataset_versions_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_dataset_versions), "__call__" + type(client.transport.get_dataset_version), "__call__" ) as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDatasetVersionsResponse( - next_page_token="next_page_token_value", + dataset_version.DatasetVersion( + name="name_value", + etag="etag_value", + big_query_dataset_name="big_query_dataset_name_value", + display_name="display_name_value", ) ) - response = await client.list_dataset_versions(request) + response = await client.get_dataset_version(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - request = dataset_service.ListDatasetVersionsRequest() + request = dataset_service.GetDatasetVersionRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListDatasetVersionsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert isinstance(response, dataset_version.DatasetVersion) + assert response.name == "name_value" + assert response.etag == "etag_value" + assert response.big_query_dataset_name == "big_query_dataset_name_value" + assert response.display_name == "display_name_value" @pytest.mark.asyncio -async def test_list_dataset_versions_async_from_dict(): - await test_list_dataset_versions_async(request_type=dict) +async def test_get_dataset_version_async_from_dict(): + await test_get_dataset_version_async(request_type=dict) -def test_list_dataset_versions_field_headers(): +def test_get_dataset_version_field_headers(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = dataset_service.ListDatasetVersionsRequest() + request = dataset_service.GetDatasetVersionRequest() - request.parent = "parent_value" + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.list_dataset_versions), "__call__" + type(client.transport.get_dataset_version), "__call__" ) as call: - call.return_value = dataset_service.ListDatasetVersionsResponse() - client.list_dataset_versions(request) + call.return_value = dataset_version.DatasetVersion() + client.get_dataset_version(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -4504,30 +5543,30 @@ def test_list_dataset_versions_field_headers(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "parent=parent_value", + "name=name_value", ) in kw["metadata"] @pytest.mark.asyncio -async def test_list_dataset_versions_field_headers_async(): +async def test_get_dataset_version_field_headers_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = dataset_service.ListDatasetVersionsRequest() + request = dataset_service.GetDatasetVersionRequest() - request.parent = "parent_value" + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_dataset_versions), "__call__" + type(client.transport.get_dataset_version), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDatasetVersionsResponse() + dataset_version.DatasetVersion() ) - await client.list_dataset_versions(request) + await client.get_dataset_version(request) # Establish that the underlying gRPC stub method was called. 
assert len(call.mock_calls) @@ -4538,37 +5577,37 @@ async def test_list_dataset_versions_field_headers_async(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "parent=parent_value", + "name=name_value", ) in kw["metadata"] -def test_list_dataset_versions_flattened(): +def test_get_dataset_version_flattened(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_dataset_versions), "__call__" + type(client.transport.get_dataset_version), "__call__" ) as call: # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListDatasetVersionsResponse() + call.return_value = dataset_version.DatasetVersion() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_dataset_versions( - parent="parent_value", + client.get_dataset_version( + name="name_value", ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - arg = args[0].parent - mock_val = "parent_value" + arg = args[0].name + mock_val = "name_value" assert arg == mock_val -def test_list_dataset_versions_flattened_error(): +def test_get_dataset_version_flattened_error(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -4576,45 +5615,45 @@ def test_list_dataset_versions_flattened_error(): # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - client.list_dataset_versions( - dataset_service.ListDatasetVersionsRequest(), - parent="parent_value", + client.get_dataset_version( + dataset_service.GetDatasetVersionRequest(), + name="name_value", ) @pytest.mark.asyncio -async def test_list_dataset_versions_flattened_async(): +async def test_get_dataset_version_flattened_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_dataset_versions), "__call__" + type(client.transport.get_dataset_version), "__call__" ) as call: # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListDatasetVersionsResponse() + call.return_value = dataset_version.DatasetVersion() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDatasetVersionsResponse() + dataset_version.DatasetVersion() ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_dataset_versions( - parent="parent_value", + response = await client.get_dataset_version( + name="name_value", ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - arg = args[0].parent - mock_val = "parent_value" + arg = args[0].name + mock_val = "name_value" assert arg == mock_val @pytest.mark.asyncio -async def test_list_dataset_versions_flattened_error_async(): +async def test_get_dataset_version_flattened_error_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -4622,317 +5661,222 @@ async def test_list_dataset_versions_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - await client.list_dataset_versions( - dataset_service.ListDatasetVersionsRequest(), - parent="parent_value", + await client.get_dataset_version( + dataset_service.GetDatasetVersionRequest(), + name="name_value", ) -def test_list_dataset_versions_pager(transport_name: str = "grpc"): +@pytest.mark.parametrize( + "request_type", + [ + dataset_service.ListDatasetVersionsRequest, + dict, + ], +) +def test_list_dataset_versions(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport=transport_name, + transport=transport, ) + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( type(client.transport.list_dataset_versions), "__call__" ) as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListDatasetVersionsResponse( - dataset_versions=[ - dataset_version.DatasetVersion(), - dataset_version.DatasetVersion(), - dataset_version.DatasetVersion(), - ], - next_page_token="abc", - ), - dataset_service.ListDatasetVersionsResponse( - dataset_versions=[], - next_page_token="def", - ), - dataset_service.ListDatasetVersionsResponse( - dataset_versions=[ - dataset_version.DatasetVersion(), - ], - next_page_token="ghi", - ), - dataset_service.ListDatasetVersionsResponse( - dataset_versions=[ - dataset_version.DatasetVersion(), - dataset_version.DatasetVersion(), - ], - ), - RuntimeError, - ) - - metadata = () - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + # Designate an appropriate return value for the call. 
+ call.return_value = dataset_service.ListDatasetVersionsResponse( + next_page_token="next_page_token_value", ) - pager = client.list_dataset_versions(request={}) + response = client.list_dataset_versions(request) - assert pager._metadata == metadata + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = dataset_service.ListDatasetVersionsRequest() + assert args[0] == request - results = list(pager) - assert len(results) == 6 - assert all(isinstance(i, dataset_version.DatasetVersion) for i in results) + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDatasetVersionsPager) + assert response.next_page_token == "next_page_token_value" -def test_list_dataset_versions_pages(transport_name: str = "grpc"): +def test_list_dataset_versions_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport=transport_name, + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( type(client.transport.list_dataset_versions), "__call__" ) as call: - # Set the response to a series of pages. 
- call.side_effect = ( - dataset_service.ListDatasetVersionsResponse( - dataset_versions=[ - dataset_version.DatasetVersion(), - dataset_version.DatasetVersion(), - dataset_version.DatasetVersion(), - ], - next_page_token="abc", - ), - dataset_service.ListDatasetVersionsResponse( - dataset_versions=[], - next_page_token="def", - ), - dataset_service.ListDatasetVersionsResponse( - dataset_versions=[ - dataset_version.DatasetVersion(), - ], - next_page_token="ghi", - ), - dataset_service.ListDatasetVersionsResponse( - dataset_versions=[ - dataset_version.DatasetVersion(), - dataset_version.DatasetVersion(), - ], - ), - RuntimeError, + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. ) - pages = list(client.list_dataset_versions(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token + client.list_dataset_versions() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == dataset_service.ListDatasetVersionsRequest() -@pytest.mark.asyncio -async def test_list_dataset_versions_async_pager(): - client = DatasetServiceAsyncClient( +def test_list_dataset_versions_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = dataset_service.ListDatasetVersionsRequest( + parent="parent_value", + filter="filter_value", + page_token="page_token_value", + order_by="order_by_value", ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.list_dataset_versions), - "__call__", - new_callable=mock.AsyncMock, + type(client.transport.list_dataset_versions), "__call__" ) as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListDatasetVersionsResponse( - dataset_versions=[ - dataset_version.DatasetVersion(), - dataset_version.DatasetVersion(), - dataset_version.DatasetVersion(), - ], - next_page_token="abc", - ), - dataset_service.ListDatasetVersionsResponse( - dataset_versions=[], - next_page_token="def", - ), - dataset_service.ListDatasetVersionsResponse( - dataset_versions=[ - dataset_version.DatasetVersion(), - ], - next_page_token="ghi", - ), - dataset_service.ListDatasetVersionsResponse( - dataset_versions=[ - dataset_version.DatasetVersion(), - dataset_version.DatasetVersion(), - ], - ), - RuntimeError, + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. ) - async_pager = await client.list_dataset_versions( - request={}, + client.list_dataset_versions(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == dataset_service.ListDatasetVersionsRequest( + parent="parent_value", + filter="filter_value", + page_token="page_token_value", + order_by="order_by_value", ) - assert async_pager.next_page_token == "abc" - responses = [] - async for response in async_pager: # pragma: no branch - responses.append(response) - - assert len(responses) == 6 - assert all(isinstance(i, dataset_version.DatasetVersion) for i in responses) -@pytest.mark.asyncio -async def test_list_dataset_versions_async_pages(): - client = DatasetServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_dataset_versions), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. 
- call.side_effect = ( - dataset_service.ListDatasetVersionsResponse( - dataset_versions=[ - dataset_version.DatasetVersion(), - dataset_version.DatasetVersion(), - dataset_version.DatasetVersion(), - ], - next_page_token="abc", - ), - dataset_service.ListDatasetVersionsResponse( - dataset_versions=[], - next_page_token="def", - ), - dataset_service.ListDatasetVersionsResponse( - dataset_versions=[ - dataset_version.DatasetVersion(), - ], - next_page_token="ghi", - ), - dataset_service.ListDatasetVersionsResponse( - dataset_versions=[ - dataset_version.DatasetVersion(), - dataset_version.DatasetVersion(), - ], - ), - RuntimeError, +def test_list_dataset_versions_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", ) - pages = [] - # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` - # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 - async for page_ in ( # pragma: no branch - await client.list_dataset_versions(request={}) - ).pages: - pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() -@pytest.mark.parametrize( - "request_type", - [ - dataset_service.RestoreDatasetVersionRequest, - dict, - ], -) -def test_restore_dataset_version(request_type, transport: str = "grpc"): - client = DatasetServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. 
- request = request_type() + # Ensure method has been cached + assert ( + client._transport.list_dataset_versions + in client._transport._wrapped_methods + ) - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.restore_dataset_version), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - response = client.restore_dataset_version(request) + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_dataset_versions + ] = mock_rpc + request = {} + client.list_dataset_versions(request) # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - request = dataset_service.RestoreDatasetVersionRequest() - assert args[0] == request + assert mock_rpc.call_count == 1 - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) + client.list_dataset_versions(request) + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 -def test_restore_dataset_version_empty_call(): + +@pytest.mark.asyncio +async def test_list_dataset_versions_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. - client = DatasetServiceClient( + client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), - transport="grpc", + transport="grpc_asyncio", ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.restore_dataset_version), "__call__" + type(client.transport.list_dataset_versions), "__call__" ) as call: - client.restore_dataset_version() + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDatasetVersionsResponse( + next_page_token="next_page_token_value", + ) + ) + response = await client.list_dataset_versions() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.RestoreDatasetVersionRequest() + assert args[0] == dataset_service.ListDatasetVersionsRequest() -def test_restore_dataset_version_non_empty_request_with_auto_populated_field(): - # This test is a coverage failsafe to make sure that UUID4 fields are - # automatically populated, according to AIP-4235, with non-empty requests. - client = DatasetServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc", - ) +@pytest.mark.asyncio +async def test_list_dataset_versions_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) - # Populate all string fields in the request which are not UUID4 - # since we want to check that UUID4 are populated automatically - # if they meet the requirements of AIP 4235. - request = dataset_service.RestoreDatasetVersionRequest( - name="name_value", - ) + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.restore_dataset_version), "__call__" - ) as call: - client.restore_dataset_version(request=request) - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.RestoreDatasetVersionRequest( - name="name_value", + # Ensure method has been cached + assert ( + client._client._transport.list_dataset_versions + in client._client._transport._wrapped_methods ) + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) -@pytest.mark.asyncio -async def test_restore_dataset_version_empty_call_async(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. - client = DatasetServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc_asyncio", - ) + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_dataset_versions + ] = mock_object - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.restore_dataset_version), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - response = await client.restore_dataset_version() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.RestoreDatasetVersionRequest() + request = {} + await client.list_dataset_versions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_dataset_versions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 @pytest.mark.asyncio -async def test_restore_dataset_version_async( +async def test_list_dataset_versions_async( transport: str = "grpc_asyncio", - request_type=dataset_service.RestoreDatasetVersionRequest, + request_type=dataset_service.ListDatasetVersionsRequest, ): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), @@ -4945,46 +5889,49 @@ async def test_restore_dataset_version_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.restore_dataset_version), "__call__" + type(client.transport.list_dataset_versions), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + dataset_service.ListDatasetVersionsResponse( + next_page_token="next_page_token_value", + ) ) - response = await client.restore_dataset_version(request) + response = await client.list_dataset_versions(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - request = dataset_service.RestoreDatasetVersionRequest() + request = dataset_service.ListDatasetVersionsRequest() assert args[0] == request # Establish that the response is the type that we expect. 
- assert isinstance(response, future.Future) + assert isinstance(response, pagers.ListDatasetVersionsAsyncPager) + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio -async def test_restore_dataset_version_async_from_dict(): - await test_restore_dataset_version_async(request_type=dict) +async def test_list_dataset_versions_async_from_dict(): + await test_list_dataset_versions_async(request_type=dict) -def test_restore_dataset_version_field_headers(): +def test_list_dataset_versions_field_headers(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = dataset_service.RestoreDatasetVersionRequest() + request = dataset_service.ListDatasetVersionsRequest() - request.name = "name_value" + request.parent = "parent_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.restore_dataset_version), "__call__" + type(client.transport.list_dataset_versions), "__call__" ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") - client.restore_dataset_version(request) + call.return_value = dataset_service.ListDatasetVersionsResponse() + client.list_dataset_versions(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -4995,30 +5942,30 @@ def test_restore_dataset_version_field_headers(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "name=name_value", + "parent=parent_value", ) in kw["metadata"] @pytest.mark.asyncio -async def test_restore_dataset_version_field_headers_async(): +async def test_list_dataset_versions_field_headers_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. 
- request = dataset_service.RestoreDatasetVersionRequest() + request = dataset_service.ListDatasetVersionsRequest() - request.name = "name_value" + request.parent = "parent_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.restore_dataset_version), "__call__" + type(client.transport.list_dataset_versions), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") + dataset_service.ListDatasetVersionsResponse() ) - await client.restore_dataset_version(request) + await client.list_dataset_versions(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -5029,37 +5976,37 @@ async def test_restore_dataset_version_field_headers_async(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "name=name_value", + "parent=parent_value", ) in kw["metadata"] -def test_restore_dataset_version_flattened(): +def test_list_dataset_versions_flattened(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.restore_dataset_version), "__call__" + type(client.transport.list_dataset_versions), "__call__" ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = dataset_service.ListDatasetVersionsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.restore_dataset_version( - name="name_value", + client.list_dataset_versions( + parent="parent_value", ) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - arg = args[0].name - mock_val = "name_value" + arg = args[0].parent + mock_val = "parent_value" assert arg == mock_val -def test_restore_dataset_version_flattened_error(): +def test_list_dataset_versions_flattened_error(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -5067,45 +6014,45 @@ def test_restore_dataset_version_flattened_error(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.restore_dataset_version( - dataset_service.RestoreDatasetVersionRequest(), - name="name_value", + client.list_dataset_versions( + dataset_service.ListDatasetVersionsRequest(), + parent="parent_value", ) @pytest.mark.asyncio -async def test_restore_dataset_version_flattened_async(): +async def test_list_dataset_versions_flattened_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.restore_dataset_version), "__call__" + type(client.transport.list_dataset_versions), "__call__" ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = dataset_service.ListDatasetVersionsResponse() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + dataset_service.ListDatasetVersionsResponse() ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.restore_dataset_version( - name="name_value", + response = await client.list_dataset_versions( + parent="parent_value", ) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - arg = args[0].name - mock_val = "name_value" + arg = args[0].parent + mock_val = "parent_value" assert arg == mock_val @pytest.mark.asyncio -async def test_restore_dataset_version_flattened_error_async(): +async def test_list_dataset_versions_flattened_error_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -5113,20 +6060,218 @@ async def test_restore_dataset_version_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.restore_dataset_version( - dataset_service.RestoreDatasetVersionRequest(), - name="name_value", + await client.list_dataset_versions( + dataset_service.ListDatasetVersionsRequest(), + parent="parent_value", ) -@pytest.mark.parametrize( - "request_type", +def test_list_dataset_versions_pager(transport_name: str = "grpc"): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_dataset_versions), "__call__" + ) as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + dataset_service.ListDatasetVersionsResponse( + dataset_versions=[ + dataset_version.DatasetVersion(), + dataset_version.DatasetVersion(), + dataset_version.DatasetVersion(), + ], + next_page_token="abc", + ), + dataset_service.ListDatasetVersionsResponse( + dataset_versions=[], + next_page_token="def", + ), + dataset_service.ListDatasetVersionsResponse( + dataset_versions=[ + dataset_version.DatasetVersion(), + ], + next_page_token="ghi", + ), + dataset_service.ListDatasetVersionsResponse( + dataset_versions=[ + dataset_version.DatasetVersion(), + dataset_version.DatasetVersion(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_dataset_versions(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, dataset_version.DatasetVersion) for i in results) + + +def test_list_dataset_versions_pages(transport_name: str = "grpc"): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_dataset_versions), "__call__" + ) as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + dataset_service.ListDatasetVersionsResponse( + dataset_versions=[ + dataset_version.DatasetVersion(), + dataset_version.DatasetVersion(), + dataset_version.DatasetVersion(), + ], + next_page_token="abc", + ), + dataset_service.ListDatasetVersionsResponse( + dataset_versions=[], + next_page_token="def", + ), + dataset_service.ListDatasetVersionsResponse( + dataset_versions=[ + dataset_version.DatasetVersion(), + ], + next_page_token="ghi", + ), + dataset_service.ListDatasetVersionsResponse( + dataset_versions=[ + dataset_version.DatasetVersion(), + dataset_version.DatasetVersion(), + ], + ), + RuntimeError, + ) + pages = list(client.list_dataset_versions(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_dataset_versions_async_pager(): + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_dataset_versions), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + dataset_service.ListDatasetVersionsResponse( + dataset_versions=[ + dataset_version.DatasetVersion(), + dataset_version.DatasetVersion(), + dataset_version.DatasetVersion(), + ], + next_page_token="abc", + ), + dataset_service.ListDatasetVersionsResponse( + dataset_versions=[], + next_page_token="def", + ), + dataset_service.ListDatasetVersionsResponse( + dataset_versions=[ + dataset_version.DatasetVersion(), + ], + next_page_token="ghi", + ), + dataset_service.ListDatasetVersionsResponse( + dataset_versions=[ + dataset_version.DatasetVersion(), + dataset_version.DatasetVersion(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_dataset_versions( + request={}, + ) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: # pragma: no branch + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, dataset_version.DatasetVersion) for i in responses) + + +@pytest.mark.asyncio +async def test_list_dataset_versions_async_pages(): + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_dataset_versions), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + dataset_service.ListDatasetVersionsResponse( + dataset_versions=[ + dataset_version.DatasetVersion(), + dataset_version.DatasetVersion(), + dataset_version.DatasetVersion(), + ], + next_page_token="abc", + ), + dataset_service.ListDatasetVersionsResponse( + dataset_versions=[], + next_page_token="def", + ), + dataset_service.ListDatasetVersionsResponse( + dataset_versions=[ + dataset_version.DatasetVersion(), + ], + next_page_token="ghi", + ), + dataset_service.ListDatasetVersionsResponse( + dataset_versions=[ + dataset_version.DatasetVersion(), + dataset_version.DatasetVersion(), + ], + ), + RuntimeError, + ) + pages = [] + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch + await client.list_dataset_versions(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", [ - dataset_service.ListDataItemsRequest, + dataset_service.RestoreDatasetVersionRequest, dict, ], ) -def test_list_data_items(request_type, transport: str = "grpc"): +def test_restore_dataset_version(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -5137,25 +6282,24 @@ def test_list_data_items(request_type, transport: str = "grpc"): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.restore_dataset_version), "__call__" + ) as call: # Designate an appropriate return value for the call. 
- call.return_value = dataset_service.ListDataItemsResponse( - next_page_token="next_page_token_value", - ) - response = client.list_data_items(request) + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.restore_dataset_version(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - request = dataset_service.ListDataItemsRequest() + request = dataset_service.RestoreDatasetVersionRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListDataItemsPager) - assert response.next_page_token == "next_page_token_value" + assert isinstance(response, future.Future) -def test_list_data_items_empty_call(): +def test_restore_dataset_version_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( @@ -5164,14 +6308,19 @@ def test_list_data_items_empty_call(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: - client.list_data_items() + with mock.patch.object( + type(client.transport.restore_dataset_version), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.restore_dataset_version() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.ListDataItemsRequest() + assert args[0] == dataset_service.RestoreDatasetVersionRequest() -def test_list_data_items_non_empty_request_with_auto_populated_field(): +def test_restore_dataset_version_non_empty_request_with_auto_populated_field(): # This test is a coverage failsafe to make sure that UUID4 fields are # automatically populated, according to AIP-4235, with non-empty requests. 
client = DatasetServiceClient( @@ -5182,29 +6331,72 @@ def test_list_data_items_non_empty_request_with_auto_populated_field(): # Populate all string fields in the request which are not UUID4 # since we want to check that UUID4 are populated automatically # if they meet the requirements of AIP 4235. - request = dataset_service.ListDataItemsRequest( - parent="parent_value", - filter="filter_value", - page_token="page_token_value", - order_by="order_by_value", + request = dataset_service.RestoreDatasetVersionRequest( + name="name_value", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: - client.list_data_items(request=request) + with mock.patch.object( + type(client.transport.restore_dataset_version), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.restore_dataset_version(request=request) call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.ListDataItemsRequest( - parent="parent_value", - filter="filter_value", - page_token="page_token_value", - order_by="order_by_value", + assert args[0] == dataset_service.RestoreDatasetVersionRequest( + name="name_value", ) -@pytest.mark.asyncio -async def test_list_data_items_empty_call_async(): - # This test is a coverage failsafe to make sure that totally empty calls, +def test_restore_dataset_version_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + 
client._transport.restore_dataset_version + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.restore_dataset_version + ] = mock_rpc + request = {} + client.restore_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.restore_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_restore_dataset_version_empty_call_async(): + # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), @@ -5212,22 +6404,73 @@ async def test_list_data_items_empty_call_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.restore_dataset_version), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDataItemsResponse( - next_page_token="next_page_token_value", - ) + operations_pb2.Operation(name="operations/spam") ) - response = await client.list_data_items() + response = await client.restore_dataset_version() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.ListDataItemsRequest() + assert args[0] == dataset_service.RestoreDatasetVersionRequest() @pytest.mark.asyncio -async def test_list_data_items_async( - transport: str = "grpc_asyncio", request_type=dataset_service.ListDataItemsRequest +async def test_restore_dataset_version_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.restore_dataset_version + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.restore_dataset_version + ] = mock_object + + request = {} + await client.restore_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.restore_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_restore_dataset_version_async( + transport: str = "grpc_asyncio", + request_type=dataset_service.RestoreDatasetVersionRequest, ): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), @@ -5239,46 +6482,47 @@ async def test_list_data_items_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.restore_dataset_version), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDataItemsResponse( - next_page_token="next_page_token_value", - ) + operations_pb2.Operation(name="operations/spam") ) - response = await client.list_data_items(request) + response = await client.restore_dataset_version(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - request = dataset_service.ListDataItemsRequest() + request = dataset_service.RestoreDatasetVersionRequest() assert args[0] == request # Establish that the response is the type that we expect. 
- assert isinstance(response, pagers.ListDataItemsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert isinstance(response, future.Future) @pytest.mark.asyncio -async def test_list_data_items_async_from_dict(): - await test_list_data_items_async(request_type=dict) +async def test_restore_dataset_version_async_from_dict(): + await test_restore_dataset_version_async(request_type=dict) -def test_list_data_items_field_headers(): +def test_restore_dataset_version_field_headers(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = dataset_service.ListDataItemsRequest() + request = dataset_service.RestoreDatasetVersionRequest() - request.parent = "parent_value" + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: - call.return_value = dataset_service.ListDataItemsResponse() - client.list_data_items(request) + with mock.patch.object( + type(client.transport.restore_dataset_version), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.restore_dataset_version(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -5289,28 +6533,30 @@ def test_list_data_items_field_headers(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "parent=parent_value", + "name=name_value", ) in kw["metadata"] @pytest.mark.asyncio -async def test_list_data_items_field_headers_async(): +async def test_restore_dataset_version_field_headers_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. 
- request = dataset_service.ListDataItemsRequest() + request = dataset_service.RestoreDatasetVersionRequest() - request.parent = "parent_value" + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.restore_dataset_version), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDataItemsResponse() + operations_pb2.Operation(name="operations/op") ) - await client.list_data_items(request) + await client.restore_dataset_version(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -5321,35 +6567,37 @@ async def test_list_data_items_field_headers_async(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "parent=parent_value", + "name=name_value", ) in kw["metadata"] -def test_list_data_items_flattened(): +def test_restore_dataset_version_flattened(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.restore_dataset_version), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListDataItemsResponse() + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_data_items( - parent="parent_value", + client.restore_dataset_version( + name="name_value", ) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - arg = args[0].parent - mock_val = "parent_value" + arg = args[0].name + mock_val = "name_value" assert arg == mock_val -def test_list_data_items_flattened_error(): +def test_restore_dataset_version_flattened_error(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -5357,43 +6605,45 @@ def test_list_data_items_flattened_error(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.list_data_items( - dataset_service.ListDataItemsRequest(), - parent="parent_value", + client.restore_dataset_version( + dataset_service.RestoreDatasetVersionRequest(), + name="name_value", ) @pytest.mark.asyncio -async def test_list_data_items_flattened_async(): +async def test_restore_dataset_version_flattened_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.restore_dataset_version), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListDataItemsResponse() + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDataItemsResponse() + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_data_items( - parent="parent_value", + response = await client.restore_dataset_version( + name="name_value", ) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - arg = args[0].parent - mock_val = "parent_value" + arg = args[0].name + mock_val = "name_value" assert arg == mock_val @pytest.mark.asyncio -async def test_list_data_items_flattened_error_async(): +async def test_restore_dataset_version_flattened_error_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -5401,211 +6651,210 @@ async def test_list_data_items_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.list_data_items( - dataset_service.ListDataItemsRequest(), - parent="parent_value", + await client.restore_dataset_version( + dataset_service.RestoreDatasetVersionRequest(), + name="name_value", ) -def test_list_data_items_pager(transport_name: str = "grpc"): +@pytest.mark.parametrize( + "request_type", + [ + dataset_service.ListDataItemsRequest, + dict, + ], +) +def test_list_data_items(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport=transport_name, + transport=transport, ) + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: - # Set the response to a series of pages. 
- call.side_effect = ( - dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - data_item.DataItem(), - ], - next_page_token="abc", - ), - dataset_service.ListDataItemsResponse( - data_items=[], - next_page_token="def", - ), - dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - ], - next_page_token="ghi", - ), - dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - ], - ), - RuntimeError, + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListDataItemsResponse( + next_page_token="next_page_token_value", ) + response = client.list_data_items(request) - metadata = () - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), - ) - pager = client.list_data_items(request={}) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = dataset_service.ListDataItemsRequest() + assert args[0] == request - assert pager._metadata == metadata + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDataItemsPager) + assert response.next_page_token == "next_page_token_value" - results = list(pager) - assert len(results) == 6 - assert all(isinstance(i, data_item.DataItem) for i in results) +def test_list_data_items_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) -def test_list_data_items_pages(transport_name: str = "grpc"): + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.list_data_items() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == dataset_service.ListDataItemsRequest() + + +def test_list_data_items_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport=transport_name, + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = dataset_service.ListDataItemsRequest( + parent="parent_value", + filter="filter_value", + page_token="page_token_value", + order_by="order_by_value", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - data_item.DataItem(), - ], - next_page_token="abc", - ), - dataset_service.ListDataItemsResponse( - data_items=[], - next_page_token="def", - ), - dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - ], - next_page_token="ghi", - ), - dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - ], - ), - RuntimeError, + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client.list_data_items(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == dataset_service.ListDataItemsRequest( + parent="parent_value", + filter="filter_value", + page_token="page_token_value", + order_by="order_by_value", ) - pages = list(client.list_data_items(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - -@pytest.mark.asyncio -async def test_list_data_items_async_pager(): - client = DatasetServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - data_item.DataItem(), - ], - next_page_token="abc", - ), - dataset_service.ListDataItemsResponse( - data_items=[], - next_page_token="def", - ), - dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - ], - next_page_token="ghi", - ), - dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - ], - ), - RuntimeError, +def test_list_data_items_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", ) - async_pager = await client.list_data_items( - request={}, + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_data_items in 
client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. ) - assert async_pager.next_page_token == "abc" - responses = [] - async for response in async_pager: # pragma: no branch - responses.append(response) + client._transport._wrapped_methods[client._transport.list_data_items] = mock_rpc + request = {} + client.list_data_items(request) - assert len(responses) == 6 - assert all(isinstance(i, data_item.DataItem) for i in responses) + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_data_items(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 @pytest.mark.asyncio -async def test_list_data_items_async_pages(): +async def test_list_data_items_empty_call_async(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock - ) as call: - # Set the response to a series of pages. 
- call.side_effect = ( - dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - data_item.DataItem(), - ], - next_page_token="abc", - ), - dataset_service.ListDataItemsResponse( - data_items=[], - next_page_token="def", - ), - dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - ], - next_page_token="ghi", - ), + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - ], - ), - RuntimeError, + next_page_token="next_page_token_value", + ) ) - pages = [] - # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` - # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 - async for page_ in ( # pragma: no branch - await client.list_data_items(request={}) - ).pages: - pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token + response = await client.list_data_items() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == dataset_service.ListDataItemsRequest() -@pytest.mark.parametrize( - "request_type", - [ - dataset_service.SearchDataItemsRequest, - dict, - ], -) -def test_search_data_items(request_type, transport: str = "grpc"): - client = DatasetServiceClient( +@pytest.mark.asyncio +async def test_list_data_items_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + 
transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_data_items + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_data_items + ] = mock_object + + request = {} + await client.list_data_items(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.list_data_items(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_data_items_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ListDataItemsRequest +): + client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -5615,250 +6864,206 @@ def test_search_data_items(request_type, transport: str = "grpc"): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.search_data_items), "__call__" - ) as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value = dataset_service.SearchDataItemsResponse( - next_page_token="next_page_token_value", + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDataItemsResponse( + next_page_token="next_page_token_value", + ) ) - response = client.search_data_items(request) + response = await client.list_data_items(request) # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 + assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - request = dataset_service.SearchDataItemsRequest() + request = dataset_service.ListDataItemsRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, pagers.SearchDataItemsPager) + assert isinstance(response, pagers.ListDataItemsAsyncPager) assert response.next_page_token == "next_page_token_value" -def test_search_data_items_empty_call(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_list_data_items_async_from_dict(): + await test_list_data_items_async(request_type=dict) + + +def test_list_data_items_field_headers(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport="grpc", ) + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListDataItemsRequest() + + request.parent = "parent_value" + # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.search_data_items), "__call__" - ) as call: - client.search_data_items() - call.assert_called() + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + call.return_value = dataset_service.ListDataItemsResponse() + client.list_data_items(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.SearchDataItemsRequest() + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] -def test_search_data_items_non_empty_request_with_auto_populated_field(): - # This test is a coverage failsafe to make sure that UUID4 fields are - # automatically populated, according to AIP-4235, with non-empty requests. - client = DatasetServiceClient( +@pytest.mark.asyncio +async def test_list_data_items_field_headers_async(): + client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), - transport="grpc", ) - # Populate all string fields in the request which are not UUID4 - # since we want to check that UUID4 are populated automatically - # if they meet the requirements of AIP 4235. - request = dataset_service.SearchDataItemsRequest( - order_by_data_item="order_by_data_item_value", - dataset="dataset_value", - saved_query="saved_query_value", - data_labeling_job="data_labeling_job_value", - data_item_filter="data_item_filter_value", - annotations_filter="annotations_filter_value", - order_by="order_by_value", - page_token="page_token_value", - ) - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.search_data_items), "__call__" - ) as call: - client.search_data_items(request=request) - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.SearchDataItemsRequest( - order_by_data_item="order_by_data_item_value", - dataset="dataset_value", - saved_query="saved_query_value", - data_labeling_job="data_labeling_job_value", - data_item_filter="data_item_filter_value", - annotations_filter="annotations_filter_value", - order_by="order_by_value", - page_token="page_token_value", - ) - - -@pytest.mark.asyncio -async def test_search_data_items_empty_call_async(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. - client = DatasetServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc_asyncio", - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.search_data_items), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.SearchDataItemsResponse( - next_page_token="next_page_token_value", - ) - ) - response = await client.search_data_items() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.SearchDataItemsRequest() - - -@pytest.mark.asyncio -async def test_search_data_items_async( - transport: str = "grpc_asyncio", request_type=dataset_service.SearchDataItemsRequest -): - client = DatasetServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
+ request = dataset_service.ListDataItemsRequest() - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() + request.parent = "parent_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.search_data_items), "__call__" - ) as call: - # Designate an appropriate return value for the call. + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.SearchDataItemsResponse( - next_page_token="next_page_token_value", - ) + dataset_service.ListDataItemsResponse() ) - response = await client.search_data_items(request) + await client.list_data_items(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - request = dataset_service.SearchDataItemsRequest() assert args[0] == request - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.SearchDataItemsAsyncPager) - assert response.next_page_token == "next_page_token_value" - - -@pytest.mark.asyncio -async def test_search_data_items_async_from_dict(): - await test_search_data_items_async(request_type=dict) + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] -def test_search_data_items_field_headers(): +def test_list_data_items_flattened(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.SearchDataItemsRequest() - - request.dataset = "dataset_value" - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.search_data_items), "__call__" - ) as call: - call.return_value = dataset_service.SearchDataItemsResponse() - client.search_data_items(request) + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListDataItemsResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_data_items( + parent="parent_value", + ) - # Establish that the underlying gRPC stub method was called. + # Establish that the underlying call was made with the expected + # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == request + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ( - "x-goog-request-params", - "dataset=dataset_value", - ) in kw["metadata"] + +def test_list_data_items_flattened_error(): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_data_items( + dataset_service.ListDataItemsRequest(), + parent="parent_value", + ) @pytest.mark.asyncio -async def test_search_data_items_field_headers_async(): +async def test_list_data_items_flattened_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.SearchDataItemsRequest() - - request.dataset = "dataset_value" - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.search_data_items), "__call__" - ) as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListDataItemsResponse() + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.SearchDataItemsResponse() + dataset_service.ListDataItemsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_data_items( + parent="parent_value", ) - await client.search_data_items(request) - # Establish that the underlying gRPC stub method was called. + # Establish that the underlying call was made with the expected + # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ( - "x-goog-request-params", - "dataset=dataset_value", - ) in kw["metadata"] +@pytest.mark.asyncio +async def test_list_data_items_flattened_error_async(): + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) -def test_search_data_items_pager(transport_name: str = "grpc"): + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_data_items( + dataset_service.ListDataItemsRequest(), + parent="parent_value", + ) + + +def test_list_data_items_pager(transport_name: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport_name, ) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.search_data_items), "__call__" - ) as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( - dataset_service.SearchDataItemsResponse( - data_item_views=[ - dataset_service.DataItemView(), - dataset_service.DataItemView(), - dataset_service.DataItemView(), + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + data_item.DataItem(), ], next_page_token="abc", ), - dataset_service.SearchDataItemsResponse( - data_item_views=[], + dataset_service.ListDataItemsResponse( + data_items=[], next_page_token="def", ), - dataset_service.SearchDataItemsResponse( - data_item_views=[ - dataset_service.DataItemView(), + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), ], next_page_token="ghi", ), - dataset_service.SearchDataItemsResponse( - data_item_views=[ - dataset_service.DataItemView(), - dataset_service.DataItemView(), + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), + data_item.DataItem(), ], ), RuntimeError, @@ -5866,101 +7071,97 @@ def test_search_data_items_pager(transport_name: str = "grpc"): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("dataset", ""),)), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.search_data_items(request={}) + pager = client.list_data_items(request={}) assert pager._metadata == metadata results = list(pager) assert len(results) == 6 - assert all(isinstance(i, dataset_service.DataItemView) for i in results) + assert all(isinstance(i, data_item.DataItem) for i in results) -def test_search_data_items_pages(transport_name: str = "grpc"): +def test_list_data_items_pages(transport_name: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport_name, ) # Mock the 
actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.search_data_items), "__call__" - ) as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( - dataset_service.SearchDataItemsResponse( - data_item_views=[ - dataset_service.DataItemView(), - dataset_service.DataItemView(), - dataset_service.DataItemView(), + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + data_item.DataItem(), ], next_page_token="abc", ), - dataset_service.SearchDataItemsResponse( - data_item_views=[], + dataset_service.ListDataItemsResponse( + data_items=[], next_page_token="def", ), - dataset_service.SearchDataItemsResponse( - data_item_views=[ - dataset_service.DataItemView(), + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), ], next_page_token="ghi", ), - dataset_service.SearchDataItemsResponse( - data_item_views=[ - dataset_service.DataItemView(), - dataset_service.DataItemView(), - ], - ), - RuntimeError, + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + ], + ), + RuntimeError, ) - pages = list(client.search_data_items(request={}).pages) + pages = list(client.list_data_items(request={}).pages) for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token @pytest.mark.asyncio -async def test_search_data_items_async_pager(): +async def test_list_data_items_async_pager(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.search_data_items), - "__call__", - new_callable=mock.AsyncMock, + type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock ) as call: # Set the response to a series of pages. call.side_effect = ( - dataset_service.SearchDataItemsResponse( - data_item_views=[ - dataset_service.DataItemView(), - dataset_service.DataItemView(), - dataset_service.DataItemView(), + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + data_item.DataItem(), ], next_page_token="abc", ), - dataset_service.SearchDataItemsResponse( - data_item_views=[], + dataset_service.ListDataItemsResponse( + data_items=[], next_page_token="def", ), - dataset_service.SearchDataItemsResponse( - data_item_views=[ - dataset_service.DataItemView(), + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), ], next_page_token="ghi", ), - dataset_service.SearchDataItemsResponse( - data_item_views=[ - dataset_service.DataItemView(), - dataset_service.DataItemView(), + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), + data_item.DataItem(), ], ), RuntimeError, ) - async_pager = await client.search_data_items( + async_pager = await client.list_data_items( request={}, ) assert async_pager.next_page_token == "abc" @@ -5969,45 +7170,43 @@ async def test_search_data_items_async_pager(): responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, dataset_service.DataItemView) for i in responses) + assert all(isinstance(i, data_item.DataItem) for i in responses) @pytest.mark.asyncio -async def test_search_data_items_async_pages(): +async def test_list_data_items_async_pages(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.search_data_items), - "__call__", - new_callable=mock.AsyncMock, + type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock ) as call: # Set the response to a series of pages. call.side_effect = ( - dataset_service.SearchDataItemsResponse( - data_item_views=[ - dataset_service.DataItemView(), - dataset_service.DataItemView(), - dataset_service.DataItemView(), + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + data_item.DataItem(), ], next_page_token="abc", ), - dataset_service.SearchDataItemsResponse( - data_item_views=[], + dataset_service.ListDataItemsResponse( + data_items=[], next_page_token="def", ), - dataset_service.SearchDataItemsResponse( - data_item_views=[ - dataset_service.DataItemView(), + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), ], next_page_token="ghi", ), - dataset_service.SearchDataItemsResponse( - data_item_views=[ - dataset_service.DataItemView(), - dataset_service.DataItemView(), + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), + data_item.DataItem(), ], ), RuntimeError, @@ -6016,7 +7215,7 @@ async def test_search_data_items_async_pages(): # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 async for page_ in ( # pragma: no branch - await client.search_data_items(request={}) + await client.list_data_items(request={}) ).pages: pages.append(page_) for page_, token in zip(pages, ["abc", "def", "ghi", ""]): @@ -6026,11 +7225,11 @@ async def test_search_data_items_async_pages(): @pytest.mark.parametrize( "request_type", [ - dataset_service.ListSavedQueriesRequest, + dataset_service.SearchDataItemsRequest, dict, ], ) -def test_list_saved_queries(request_type, transport: str = "grpc"): +def test_search_data_items(request_type, 
transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -6042,26 +7241,26 @@ def test_list_saved_queries(request_type, transport: str = "grpc"): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_saved_queries), "__call__" + type(client.transport.search_data_items), "__call__" ) as call: # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListSavedQueriesResponse( + call.return_value = dataset_service.SearchDataItemsResponse( next_page_token="next_page_token_value", ) - response = client.list_saved_queries(request) + response = client.search_data_items(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - request = dataset_service.ListSavedQueriesRequest() + request = dataset_service.SearchDataItemsRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListSavedQueriesPager) + assert isinstance(response, pagers.SearchDataItemsPager) assert response.next_page_token == "next_page_token_value" -def test_list_saved_queries_empty_call(): +def test_search_data_items_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( @@ -6071,15 +7270,18 @@ def test_list_saved_queries_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_saved_queries), "__call__" + type(client.transport.search_data_items), "__call__" ) as call: - client.list_saved_queries() + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client.search_data_items() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.ListSavedQueriesRequest() + assert args[0] == dataset_service.SearchDataItemsRequest() -def test_list_saved_queries_non_empty_request_with_auto_populated_field(): +def test_search_data_items_non_empty_request_with_auto_populated_field(): # This test is a coverage failsafe to make sure that UUID4 fields are # automatically populated, according to AIP-4235, with non-empty requests. client = DatasetServiceClient( @@ -6090,30 +7292,78 @@ def test_list_saved_queries_non_empty_request_with_auto_populated_field(): # Populate all string fields in the request which are not UUID4 # since we want to check that UUID4 are populated automatically # if they meet the requirements of AIP 4235. - request = dataset_service.ListSavedQueriesRequest( - parent="parent_value", - filter="filter_value", - page_token="page_token_value", + request = dataset_service.SearchDataItemsRequest( + order_by_data_item="order_by_data_item_value", + dataset="dataset_value", + saved_query="saved_query_value", + data_labeling_job="data_labeling_job_value", + data_item_filter="data_item_filter_value", + annotations_filter="annotations_filter_value", order_by="order_by_value", + page_token="page_token_value", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_saved_queries), "__call__" + type(client.transport.search_data_items), "__call__" ) as call: - client.list_saved_queries(request=request) + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client.search_data_items(request=request) call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.ListSavedQueriesRequest( - parent="parent_value", - filter="filter_value", - page_token="page_token_value", + assert args[0] == dataset_service.SearchDataItemsRequest( + order_by_data_item="order_by_data_item_value", + dataset="dataset_value", + saved_query="saved_query_value", + data_labeling_job="data_labeling_job_value", + data_item_filter="data_item_filter_value", + annotations_filter="annotations_filter_value", order_by="order_by_value", + page_token="page_token_value", + ) + + +def test_search_data_items_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.search_data_items in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. ) + client._transport._wrapped_methods[ + client._transport.search_data_items + ] = mock_rpc + request = {} + client.search_data_items(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_data_items(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 @pytest.mark.asyncio -async def test_list_saved_queries_empty_call_async(): +async def test_search_data_items_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceAsyncClient( @@ -6123,24 +7373,69 @@ async def test_list_saved_queries_empty_call_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_saved_queries), "__call__" + type(client.transport.search_data_items), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListSavedQueriesResponse( + dataset_service.SearchDataItemsResponse( next_page_token="next_page_token_value", ) ) - response = await client.list_saved_queries() + response = await client.search_data_items() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.ListSavedQueriesRequest() + assert args[0] == dataset_service.SearchDataItemsRequest() @pytest.mark.asyncio -async def test_list_saved_queries_async( +async def test_search_data_items_async_use_cached_wrapped_rpc( transport: str = "grpc_asyncio", - request_type=dataset_service.ListSavedQueriesRequest, +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been 
cached + assert ( + client._client._transport.search_data_items + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.search_data_items + ] = mock_object + + request = {} + await client.search_data_items(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.search_data_items(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_search_data_items_async( + transport: str = "grpc_asyncio", request_type=dataset_service.SearchDataItemsRequest ): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), @@ -6153,49 +7448,49 @@ async def test_list_saved_queries_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_saved_queries), "__call__" + type(client.transport.search_data_items), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListSavedQueriesResponse( + dataset_service.SearchDataItemsResponse( next_page_token="next_page_token_value", ) ) - response = await client.list_saved_queries(request) + response = await client.search_data_items(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - request = dataset_service.ListSavedQueriesRequest() + request = dataset_service.SearchDataItemsRequest() assert args[0] == request # Establish that the response is the type that we expect. 
- assert isinstance(response, pagers.ListSavedQueriesAsyncPager) + assert isinstance(response, pagers.SearchDataItemsAsyncPager) assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio -async def test_list_saved_queries_async_from_dict(): - await test_list_saved_queries_async(request_type=dict) +async def test_search_data_items_async_from_dict(): + await test_search_data_items_async(request_type=dict) -def test_list_saved_queries_field_headers(): +def test_search_data_items_field_headers(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = dataset_service.ListSavedQueriesRequest() + request = dataset_service.SearchDataItemsRequest() - request.parent = "parent_value" + request.dataset = "dataset_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_saved_queries), "__call__" + type(client.transport.search_data_items), "__call__" ) as call: - call.return_value = dataset_service.ListSavedQueriesResponse() - client.list_saved_queries(request) + call.return_value = dataset_service.SearchDataItemsResponse() + client.search_data_items(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -6206,30 +7501,30 @@ def test_list_saved_queries_field_headers(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "parent=parent_value", + "dataset=dataset_value", ) in kw["metadata"] @pytest.mark.asyncio -async def test_list_saved_queries_field_headers_async(): +async def test_search_data_items_field_headers_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. 
- request = dataset_service.ListSavedQueriesRequest() + request = dataset_service.SearchDataItemsRequest() - request.parent = "parent_value" + request.dataset = "dataset_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_saved_queries), "__call__" + type(client.transport.search_data_items), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListSavedQueriesResponse() + dataset_service.SearchDataItemsResponse() ) - await client.list_saved_queries(request) + await client.search_data_items(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -6240,130 +7535,44 @@ async def test_list_saved_queries_field_headers_async(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "parent=parent_value", + "dataset=dataset_value", ) in kw["metadata"] -def test_list_saved_queries_flattened(): +def test_search_data_items_pager(transport_name: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_saved_queries), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListSavedQueriesResponse() - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.list_saved_queries( - parent="parent_value", - ) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - arg = args[0].parent - mock_val = "parent_value" - assert arg == mock_val - - -def test_list_saved_queries_flattened_error(): - client = DatasetServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_saved_queries( - dataset_service.ListSavedQueriesRequest(), - parent="parent_value", - ) - - -@pytest.mark.asyncio -async def test_list_saved_queries_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_saved_queries), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListSavedQueriesResponse() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListSavedQueriesResponse() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.list_saved_queries( - parent="parent_value", - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - arg = args[0].parent - mock_val = "parent_value" - assert arg == mock_val - - -@pytest.mark.asyncio -async def test_list_saved_queries_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
- with pytest.raises(ValueError): - await client.list_saved_queries( - dataset_service.ListSavedQueriesRequest(), - parent="parent_value", - ) - - -def test_list_saved_queries_pager(transport_name: str = "grpc"): - client = DatasetServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport_name, - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_saved_queries), "__call__" + type(client.transport.search_data_items), "__call__" ) as call: # Set the response to a series of pages. call.side_effect = ( - dataset_service.ListSavedQueriesResponse( - saved_queries=[ - saved_query.SavedQuery(), - saved_query.SavedQuery(), - saved_query.SavedQuery(), + dataset_service.SearchDataItemsResponse( + data_item_views=[ + dataset_service.DataItemView(), + dataset_service.DataItemView(), + dataset_service.DataItemView(), ], next_page_token="abc", ), - dataset_service.ListSavedQueriesResponse( - saved_queries=[], + dataset_service.SearchDataItemsResponse( + data_item_views=[], next_page_token="def", ), - dataset_service.ListSavedQueriesResponse( - saved_queries=[ - saved_query.SavedQuery(), + dataset_service.SearchDataItemsResponse( + data_item_views=[ + dataset_service.DataItemView(), ], next_page_token="ghi", ), - dataset_service.ListSavedQueriesResponse( - saved_queries=[ - saved_query.SavedQuery(), - saved_query.SavedQuery(), + dataset_service.SearchDataItemsResponse( + data_item_views=[ + dataset_service.DataItemView(), + dataset_service.DataItemView(), ], ), RuntimeError, @@ -6371,18 +7580,18 @@ def test_list_saved_queries_pager(transport_name: str = "grpc"): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata((("dataset", ""),)), ) - pager = client.list_saved_queries(request={}) + pager = client.search_data_items(request={}) assert pager._metadata == metadata results = 
list(pager) assert len(results) == 6 - assert all(isinstance(i, saved_query.SavedQuery) for i in results) + assert all(isinstance(i, dataset_service.DataItemView) for i in results) -def test_list_saved_queries_pages(transport_name: str = "grpc"): +def test_search_data_items_pages(transport_name: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport_name, @@ -6390,82 +7599,82 @@ def test_list_saved_queries_pages(transport_name: str = "grpc"): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_saved_queries), "__call__" + type(client.transport.search_data_items), "__call__" ) as call: # Set the response to a series of pages. call.side_effect = ( - dataset_service.ListSavedQueriesResponse( - saved_queries=[ - saved_query.SavedQuery(), - saved_query.SavedQuery(), - saved_query.SavedQuery(), + dataset_service.SearchDataItemsResponse( + data_item_views=[ + dataset_service.DataItemView(), + dataset_service.DataItemView(), + dataset_service.DataItemView(), ], next_page_token="abc", ), - dataset_service.ListSavedQueriesResponse( - saved_queries=[], + dataset_service.SearchDataItemsResponse( + data_item_views=[], next_page_token="def", ), - dataset_service.ListSavedQueriesResponse( - saved_queries=[ - saved_query.SavedQuery(), + dataset_service.SearchDataItemsResponse( + data_item_views=[ + dataset_service.DataItemView(), ], next_page_token="ghi", ), - dataset_service.ListSavedQueriesResponse( - saved_queries=[ - saved_query.SavedQuery(), - saved_query.SavedQuery(), + dataset_service.SearchDataItemsResponse( + data_item_views=[ + dataset_service.DataItemView(), + dataset_service.DataItemView(), ], ), RuntimeError, ) - pages = list(client.list_saved_queries(request={}).pages) + pages = list(client.search_data_items(request={}).pages) for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token 
@pytest.mark.asyncio -async def test_list_saved_queries_async_pager(): +async def test_search_data_items_async_pager(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_saved_queries), + type(client.transport.search_data_items), "__call__", new_callable=mock.AsyncMock, ) as call: # Set the response to a series of pages. call.side_effect = ( - dataset_service.ListSavedQueriesResponse( - saved_queries=[ - saved_query.SavedQuery(), - saved_query.SavedQuery(), - saved_query.SavedQuery(), + dataset_service.SearchDataItemsResponse( + data_item_views=[ + dataset_service.DataItemView(), + dataset_service.DataItemView(), + dataset_service.DataItemView(), ], next_page_token="abc", ), - dataset_service.ListSavedQueriesResponse( - saved_queries=[], + dataset_service.SearchDataItemsResponse( + data_item_views=[], next_page_token="def", ), - dataset_service.ListSavedQueriesResponse( - saved_queries=[ - saved_query.SavedQuery(), + dataset_service.SearchDataItemsResponse( + data_item_views=[ + dataset_service.DataItemView(), ], next_page_token="ghi", ), - dataset_service.ListSavedQueriesResponse( - saved_queries=[ - saved_query.SavedQuery(), - saved_query.SavedQuery(), + dataset_service.SearchDataItemsResponse( + data_item_views=[ + dataset_service.DataItemView(), + dataset_service.DataItemView(), ], ), RuntimeError, ) - async_pager = await client.list_saved_queries( + async_pager = await client.search_data_items( request={}, ) assert async_pager.next_page_token == "abc" @@ -6474,45 +7683,45 @@ async def test_list_saved_queries_async_pager(): responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, saved_query.SavedQuery) for i in responses) + assert all(isinstance(i, dataset_service.DataItemView) for i in responses) @pytest.mark.asyncio -async def test_list_saved_queries_async_pages(): +async def 
test_search_data_items_async_pages(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_saved_queries), + type(client.transport.search_data_items), "__call__", new_callable=mock.AsyncMock, ) as call: # Set the response to a series of pages. call.side_effect = ( - dataset_service.ListSavedQueriesResponse( - saved_queries=[ - saved_query.SavedQuery(), - saved_query.SavedQuery(), - saved_query.SavedQuery(), + dataset_service.SearchDataItemsResponse( + data_item_views=[ + dataset_service.DataItemView(), + dataset_service.DataItemView(), + dataset_service.DataItemView(), ], next_page_token="abc", ), - dataset_service.ListSavedQueriesResponse( - saved_queries=[], + dataset_service.SearchDataItemsResponse( + data_item_views=[], next_page_token="def", ), - dataset_service.ListSavedQueriesResponse( - saved_queries=[ - saved_query.SavedQuery(), + dataset_service.SearchDataItemsResponse( + data_item_views=[ + dataset_service.DataItemView(), ], next_page_token="ghi", ), - dataset_service.ListSavedQueriesResponse( - saved_queries=[ - saved_query.SavedQuery(), - saved_query.SavedQuery(), + dataset_service.SearchDataItemsResponse( + data_item_views=[ + dataset_service.DataItemView(), + dataset_service.DataItemView(), ], ), RuntimeError, @@ -6521,7 +7730,7 @@ async def test_list_saved_queries_async_pages(): # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 async for page_ in ( # pragma: no branch - await client.list_saved_queries(request={}) + await client.search_data_items(request={}) ).pages: pages.append(page_) for page_, token in zip(pages, ["abc", "def", "ghi", ""]): @@ -6531,11 +7740,11 @@ async def test_list_saved_queries_async_pages(): @pytest.mark.parametrize( "request_type", [ - 
dataset_service.DeleteSavedQueryRequest, + dataset_service.ListSavedQueriesRequest, dict, ], ) -def test_delete_saved_query(request_type, transport: str = "grpc"): +def test_list_saved_queries(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -6547,23 +7756,26 @@ def test_delete_saved_query(request_type, transport: str = "grpc"): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_saved_query), "__call__" + type(client.transport.list_saved_queries), "__call__" ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - response = client.delete_saved_query(request) + call.return_value = dataset_service.ListSavedQueriesResponse( + next_page_token="next_page_token_value", + ) + response = client.list_saved_queries(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - request = dataset_service.DeleteSavedQueryRequest() + request = dataset_service.ListSavedQueriesRequest() assert args[0] == request # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) + assert isinstance(response, pagers.ListSavedQueriesPager) + assert response.next_page_token == "next_page_token_value" -def test_delete_saved_query_empty_call(): +def test_list_saved_queries_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( @@ -6573,15 +7785,18 @@ def test_delete_saved_query_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.delete_saved_query), "__call__" + type(client.transport.list_saved_queries), "__call__" ) as call: - client.delete_saved_query() + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.list_saved_queries() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.DeleteSavedQueryRequest() + assert args[0] == dataset_service.ListSavedQueriesRequest() -def test_delete_saved_query_non_empty_request_with_auto_populated_field(): +def test_list_saved_queries_non_empty_request_with_auto_populated_field(): # This test is a coverage failsafe to make sure that UUID4 fields are # automatically populated, according to AIP-4235, with non-empty requests. client = DatasetServiceClient( @@ -6592,49 +7807,145 @@ def test_delete_saved_query_non_empty_request_with_auto_populated_field(): # Populate all string fields in the request which are not UUID4 # since we want to check that UUID4 are populated automatically # if they meet the requirements of AIP 4235. - request = dataset_service.DeleteSavedQueryRequest( - name="name_value", + request = dataset_service.ListSavedQueriesRequest( + parent="parent_value", + filter="filter_value", + page_token="page_token_value", + order_by="order_by_value", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_saved_query), "__call__" + type(client.transport.list_saved_queries), "__call__" ) as call: - client.delete_saved_query(request=request) + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client.list_saved_queries(request=request) call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.DeleteSavedQueryRequest( - name="name_value", + assert args[0] == dataset_service.ListSavedQueriesRequest( + parent="parent_value", + filter="filter_value", + page_token="page_token_value", + order_by="order_by_value", ) -@pytest.mark.asyncio -async def test_delete_saved_query_empty_call_async(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. - client = DatasetServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc_asyncio", +def test_list_saved_queries_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_saved_queries in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_saved_queries + ] = mock_rpc + request = {} + client.list_saved_queries(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_saved_queries(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_saved_queries_empty_call_async(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_saved_query), "__call__" + type(client.transport.list_saved_queries), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + dataset_service.ListSavedQueriesResponse( + next_page_token="next_page_token_value", + ) ) - response = await client.delete_saved_query() + response = await client.list_saved_queries() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.DeleteSavedQueryRequest() + assert args[0] == dataset_service.ListSavedQueriesRequest() @pytest.mark.asyncio -async def test_delete_saved_query_async( +async def test_list_saved_queries_async_use_cached_wrapped_rpc( transport: str = "grpc_asyncio", - request_type=dataset_service.DeleteSavedQueryRequest, +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + 
client._client._transport.list_saved_queries + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_saved_queries + ] = mock_object + + request = {} + await client.list_saved_queries(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.list_saved_queries(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_saved_queries_async( + transport: str = "grpc_asyncio", + request_type=dataset_service.ListSavedQueriesRequest, ): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), @@ -6647,46 +7958,49 @@ async def test_delete_saved_query_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_saved_query), "__call__" + type(client.transport.list_saved_queries), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + dataset_service.ListSavedQueriesResponse( + next_page_token="next_page_token_value", + ) ) - response = await client.delete_saved_query(request) + response = await client.list_saved_queries(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - request = dataset_service.DeleteSavedQueryRequest() + request = dataset_service.ListSavedQueriesRequest() assert args[0] == request # Establish that the response is the type that we expect. 
- assert isinstance(response, future.Future) + assert isinstance(response, pagers.ListSavedQueriesAsyncPager) + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio -async def test_delete_saved_query_async_from_dict(): - await test_delete_saved_query_async(request_type=dict) +async def test_list_saved_queries_async_from_dict(): + await test_list_saved_queries_async(request_type=dict) -def test_delete_saved_query_field_headers(): +def test_list_saved_queries_field_headers(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = dataset_service.DeleteSavedQueryRequest() + request = dataset_service.ListSavedQueriesRequest() - request.name = "name_value" + request.parent = "parent_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_saved_query), "__call__" + type(client.transport.list_saved_queries), "__call__" ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") - client.delete_saved_query(request) + call.return_value = dataset_service.ListSavedQueriesResponse() + client.list_saved_queries(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -6697,30 +8011,30 @@ def test_delete_saved_query_field_headers(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "name=name_value", + "parent=parent_value", ) in kw["metadata"] @pytest.mark.asyncio -async def test_delete_saved_query_field_headers_async(): +async def test_list_saved_queries_field_headers_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. 
- request = dataset_service.DeleteSavedQueryRequest() + request = dataset_service.ListSavedQueriesRequest() - request.name = "name_value" + request.parent = "parent_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_saved_query), "__call__" + type(client.transport.list_saved_queries), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") + dataset_service.ListSavedQueriesResponse() ) - await client.delete_saved_query(request) + await client.list_saved_queries(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -6731,37 +8045,37 @@ async def test_delete_saved_query_field_headers_async(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "name=name_value", + "parent=parent_value", ) in kw["metadata"] -def test_delete_saved_query_flattened(): +def test_list_saved_queries_flattened(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_saved_query), "__call__" + type(client.transport.list_saved_queries), "__call__" ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = dataset_service.ListSavedQueriesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_saved_query( - name="name_value", + client.list_saved_queries( + parent="parent_value", ) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - arg = args[0].name - mock_val = "name_value" + arg = args[0].parent + mock_val = "parent_value" assert arg == mock_val -def test_delete_saved_query_flattened_error(): +def test_list_saved_queries_flattened_error(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -6769,45 +8083,45 @@ def test_delete_saved_query_flattened_error(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.delete_saved_query( - dataset_service.DeleteSavedQueryRequest(), - name="name_value", + client.list_saved_queries( + dataset_service.ListSavedQueriesRequest(), + parent="parent_value", ) @pytest.mark.asyncio -async def test_delete_saved_query_flattened_async(): +async def test_list_saved_queries_flattened_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_saved_query), "__call__" + type(client.transport.list_saved_queries), "__call__" ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = dataset_service.ListSavedQueriesResponse() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + dataset_service.ListSavedQueriesResponse() ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_saved_query( - name="name_value", + response = await client.list_saved_queries( + parent="parent_value", ) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - arg = args[0].name - mock_val = "name_value" + arg = args[0].parent + mock_val = "parent_value" assert arg == mock_val @pytest.mark.asyncio -async def test_delete_saved_query_flattened_error_async(): +async def test_list_saved_queries_flattened_error_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -6815,432 +8129,416 @@ async def test_delete_saved_query_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.delete_saved_query( - dataset_service.DeleteSavedQueryRequest(), - name="name_value", + await client.list_saved_queries( + dataset_service.ListSavedQueriesRequest(), + parent="parent_value", ) -@pytest.mark.parametrize( - "request_type", - [ - dataset_service.GetAnnotationSpecRequest, - dict, - ], -) -def test_get_annotation_spec(request_type, transport: str = "grpc"): +def test_list_saved_queries_pager(transport_name: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport=transport, + transport=transport_name, ) - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" + type(client.transport.list_saved_queries), "__call__" ) as call: - # Designate an appropriate return value for the call. - call.return_value = annotation_spec.AnnotationSpec( - name="name_value", - display_name="display_name_value", - etag="etag_value", + # Set the response to a series of pages. 
+ call.side_effect = ( + dataset_service.ListSavedQueriesResponse( + saved_queries=[ + saved_query.SavedQuery(), + saved_query.SavedQuery(), + saved_query.SavedQuery(), + ], + next_page_token="abc", + ), + dataset_service.ListSavedQueriesResponse( + saved_queries=[], + next_page_token="def", + ), + dataset_service.ListSavedQueriesResponse( + saved_queries=[ + saved_query.SavedQuery(), + ], + next_page_token="ghi", + ), + dataset_service.ListSavedQueriesResponse( + saved_queries=[ + saved_query.SavedQuery(), + saved_query.SavedQuery(), + ], + ), + RuntimeError, ) - response = client.get_annotation_spec(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - request = dataset_service.GetAnnotationSpecRequest() - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, annotation_spec.AnnotationSpec) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.etag == "etag_value" + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_saved_queries(request={}) -def test_get_annotation_spec_empty_call(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. - client = DatasetServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc", - ) + assert pager._metadata == metadata - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: - client.get_annotation_spec() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.GetAnnotationSpecRequest() + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, saved_query.SavedQuery) for i in results) -def test_get_annotation_spec_non_empty_request_with_auto_populated_field(): - # This test is a coverage failsafe to make sure that UUID4 fields are - # automatically populated, according to AIP-4235, with non-empty requests. +def test_list_saved_queries_pages(transport_name: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport="grpc", - ) - - # Populate all string fields in the request which are not UUID4 - # since we want to check that UUID4 are populated automatically - # if they meet the requirements of AIP 4235. - request = dataset_service.GetAnnotationSpecRequest( - name="name_value", + transport=transport_name, ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" + type(client.transport.list_saved_queries), "__call__" ) as call: - client.get_annotation_spec(request=request) - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.GetAnnotationSpecRequest( - name="name_value", + # Set the response to a series of pages. 
+ call.side_effect = ( + dataset_service.ListSavedQueriesResponse( + saved_queries=[ + saved_query.SavedQuery(), + saved_query.SavedQuery(), + saved_query.SavedQuery(), + ], + next_page_token="abc", + ), + dataset_service.ListSavedQueriesResponse( + saved_queries=[], + next_page_token="def", + ), + dataset_service.ListSavedQueriesResponse( + saved_queries=[ + saved_query.SavedQuery(), + ], + next_page_token="ghi", + ), + dataset_service.ListSavedQueriesResponse( + saved_queries=[ + saved_query.SavedQuery(), + saved_query.SavedQuery(), + ], + ), + RuntimeError, ) + pages = list(client.list_saved_queries(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token @pytest.mark.asyncio -async def test_get_annotation_spec_empty_call_async(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. +async def test_list_saved_queries_async_pager(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), - transport="grpc_asyncio", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" + type(client.transport.list_saved_queries), + "__call__", + new_callable=mock.AsyncMock, ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - annotation_spec.AnnotationSpec( - name="name_value", - display_name="display_name_value", - etag="etag_value", - ) + # Set the response to a series of pages. 
+ call.side_effect = ( + dataset_service.ListSavedQueriesResponse( + saved_queries=[ + saved_query.SavedQuery(), + saved_query.SavedQuery(), + saved_query.SavedQuery(), + ], + next_page_token="abc", + ), + dataset_service.ListSavedQueriesResponse( + saved_queries=[], + next_page_token="def", + ), + dataset_service.ListSavedQueriesResponse( + saved_queries=[ + saved_query.SavedQuery(), + ], + next_page_token="ghi", + ), + dataset_service.ListSavedQueriesResponse( + saved_queries=[ + saved_query.SavedQuery(), + saved_query.SavedQuery(), + ], + ), + RuntimeError, ) - response = await client.get_annotation_spec() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.GetAnnotationSpecRequest() + async_pager = await client.list_saved_queries( + request={}, + ) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: # pragma: no branch + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, saved_query.SavedQuery) for i in responses) @pytest.mark.asyncio -async def test_get_annotation_spec_async( - transport: str = "grpc_asyncio", - request_type=dataset_service.GetAnnotationSpecRequest, -): +async def test_list_saved_queries_async_pages(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), - transport=transport, ) - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" + type(client.transport.list_saved_queries), + "__call__", + new_callable=mock.AsyncMock, ) as call: - # Designate an appropriate return value for the call. 
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - annotation_spec.AnnotationSpec( - name="name_value", - display_name="display_name_value", - etag="etag_value", - ) + # Set the response to a series of pages. + call.side_effect = ( + dataset_service.ListSavedQueriesResponse( + saved_queries=[ + saved_query.SavedQuery(), + saved_query.SavedQuery(), + saved_query.SavedQuery(), + ], + next_page_token="abc", + ), + dataset_service.ListSavedQueriesResponse( + saved_queries=[], + next_page_token="def", + ), + dataset_service.ListSavedQueriesResponse( + saved_queries=[ + saved_query.SavedQuery(), + ], + next_page_token="ghi", + ), + dataset_service.ListSavedQueriesResponse( + saved_queries=[ + saved_query.SavedQuery(), + saved_query.SavedQuery(), + ], + ), + RuntimeError, ) - response = await client.get_annotation_spec(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - request = dataset_service.GetAnnotationSpecRequest() - assert args[0] == request - - # Establish that the response is the type that we expect. 
- assert isinstance(response, annotation_spec.AnnotationSpec) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.etag == "etag_value" - - -@pytest.mark.asyncio -async def test_get_annotation_spec_async_from_dict(): - await test_get_annotation_spec_async(request_type=dict) + pages = [] + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch + await client.list_saved_queries(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token -def test_get_annotation_spec_field_headers(): +@pytest.mark.parametrize( + "request_type", + [ + dataset_service.DeleteSavedQueryRequest, + dict, + ], +) +def test_delete_saved_query(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport=transport, ) - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.GetAnnotationSpecRequest() - - request.name = "name_value" + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" + type(client.transport.delete_saved_query), "__call__" ) as call: - call.return_value = annotation_spec.AnnotationSpec() - client.get_annotation_spec(request) + # Designate an appropriate return value for the call. 
+ call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.delete_saved_query(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] + request = dataset_service.DeleteSavedQueryRequest() assert args[0] == request - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ( - "x-goog-request-params", - "name=name_value", - ) in kw["metadata"] + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) -@pytest.mark.asyncio -async def test_get_annotation_spec_field_headers_async(): - client = DatasetServiceAsyncClient( +def test_delete_saved_query_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", ) - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.GetAnnotationSpecRequest() - - request.name = "name_value" - # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" + type(client.transport.delete_saved_query), "__call__" ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - annotation_spec.AnnotationSpec() + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. ) - await client.get_annotation_spec(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) + client.delete_saved_query() + call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. 
- _, _, kw = call.mock_calls[0] - assert ( - "x-goog-request-params", - "name=name_value", - ) in kw["metadata"] + assert args[0] == dataset_service.DeleteSavedQueryRequest() -def test_get_annotation_spec_flattened(): +def test_delete_saved_query_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = dataset_service.DeleteSavedQueryRequest( + name="name_value", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" + type(client.transport.delete_saved_query), "__call__" ) as call: - # Designate an appropriate return value for the call. - call.return_value = annotation_spec.AnnotationSpec() - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.get_annotation_spec( + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.delete_saved_query(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == dataset_service.DeleteSavedQueryRequest( name="name_value", ) - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - arg = args[0].name - mock_val = "name_value" - assert arg == mock_val +def test_delete_saved_query_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) -def test_get_annotation_spec_flattened_error(): - client = DatasetServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_annotation_spec( - dataset_service.GetAnnotationSpecRequest(), - name="name_value", + # Ensure method has been cached + assert ( + client._transport.delete_saved_query in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. ) + client._transport._wrapped_methods[ + client._transport.delete_saved_query + ] = mock_rpc + request = {} + client.delete_saved_query(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_saved_query(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 @pytest.mark.asyncio -async def test_get_annotation_spec_flattened_async(): +async def test_delete_saved_query_empty_call_async(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" + type(client.transport.delete_saved_query), "__call__" ) as call: # Designate an appropriate return value for the call. - call.return_value = annotation_spec.AnnotationSpec() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - annotation_spec.AnnotationSpec() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.get_annotation_spec( - name="name_value", + operations_pb2.Operation(name="operations/spam") ) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) + response = await client.delete_saved_query() + call.assert_called() _, args, _ = call.mock_calls[0] - arg = args[0].name - mock_val = "name_value" - assert arg == mock_val + assert args[0] == dataset_service.DeleteSavedQueryRequest() @pytest.mark.asyncio -async def test_get_annotation_spec_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.get_annotation_spec( - dataset_service.GetAnnotationSpecRequest(), - name="name_value", +async def test_delete_saved_query_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, ) + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() -@pytest.mark.parametrize( - "request_type", - [ - dataset_service.ListAnnotationsRequest, - dict, - ], -) -def test_list_annotations(request_type, transport: str = "grpc"): - client = DatasetServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: - # Designate an appropriate return value for the call. 
- call.return_value = dataset_service.ListAnnotationsResponse( - next_page_token="next_page_token_value", + # Ensure method has been cached + assert ( + client._client._transport.delete_saved_query + in client._client._transport._wrapped_methods ) - response = client.list_annotations(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - request = dataset_service.ListAnnotationsRequest() - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListAnnotationsPager) - assert response.next_page_token == "next_page_token_value" - - -def test_list_annotations_empty_call(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. - client = DatasetServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc", - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: - client.list_annotations() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.ListAnnotationsRequest() + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) -def test_list_annotations_non_empty_request_with_auto_populated_field(): - # This test is a coverage failsafe to make sure that UUID4 fields are - # automatically populated, according to AIP-4235, with non-empty requests. 
- client = DatasetServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc", - ) + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_saved_query + ] = mock_object - # Populate all string fields in the request which are not UUID4 - # since we want to check that UUID4 are populated automatically - # if they meet the requirements of AIP 4235. - request = dataset_service.ListAnnotationsRequest( - parent="parent_value", - filter="filter_value", - page_token="page_token_value", - order_by="order_by_value", - ) + request = {} + await client.delete_saved_query(request) - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: - client.list_annotations(request=request) - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.ListAnnotationsRequest( - parent="parent_value", - filter="filter_value", - page_token="page_token_value", - order_by="order_by_value", - ) + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() -@pytest.mark.asyncio -async def test_list_annotations_empty_call_async(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. - client = DatasetServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc_asyncio", - ) + await client.delete_saved_query(request) - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: - # Designate an appropriate return value for the call. 
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListAnnotationsResponse( - next_page_token="next_page_token_value", - ) - ) - response = await client.list_annotations() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == dataset_service.ListAnnotationsRequest() + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 @pytest.mark.asyncio -async def test_list_annotations_async( - transport: str = "grpc_asyncio", request_type=dataset_service.ListAnnotationsRequest +async def test_delete_saved_query_async( + transport: str = "grpc_asyncio", + request_type=dataset_service.DeleteSavedQueryRequest, ): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), @@ -7252,46 +8550,47 @@ async def test_list_annotations_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_saved_query), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListAnnotationsResponse( - next_page_token="next_page_token_value", - ) + operations_pb2.Operation(name="operations/spam") ) - response = await client.list_annotations(request) + response = await client.delete_saved_query(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - request = dataset_service.ListAnnotationsRequest() + request = dataset_service.DeleteSavedQueryRequest() assert args[0] == request # Establish that the response is the type that we expect. 
- assert isinstance(response, pagers.ListAnnotationsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert isinstance(response, future.Future) @pytest.mark.asyncio -async def test_list_annotations_async_from_dict(): - await test_list_annotations_async(request_type=dict) +async def test_delete_saved_query_async_from_dict(): + await test_delete_saved_query_async(request_type=dict) -def test_list_annotations_field_headers(): +def test_delete_saved_query_field_headers(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = dataset_service.ListAnnotationsRequest() + request = dataset_service.DeleteSavedQueryRequest() - request.parent = "parent_value" + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: - call.return_value = dataset_service.ListAnnotationsResponse() - client.list_annotations(request) + with mock.patch.object( + type(client.transport.delete_saved_query), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.delete_saved_query(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -7302,28 +8601,30 @@ def test_list_annotations_field_headers(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "parent=parent_value", + "name=name_value", ) in kw["metadata"] @pytest.mark.asyncio -async def test_list_annotations_field_headers_async(): +async def test_delete_saved_query_field_headers_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. 
- request = dataset_service.ListAnnotationsRequest() + request = dataset_service.DeleteSavedQueryRequest() - request.parent = "parent_value" + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_saved_query), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListAnnotationsResponse() + operations_pb2.Operation(name="operations/op") ) - await client.list_annotations(request) + await client.delete_saved_query(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -7334,35 +8635,37 @@ async def test_list_annotations_field_headers_async(): _, _, kw = call.mock_calls[0] assert ( "x-goog-request-params", - "parent=parent_value", + "name=name_value", ) in kw["metadata"] -def test_list_annotations_flattened(): +def test_delete_saved_query_flattened(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_saved_query), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListAnnotationsResponse() + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_annotations( - parent="parent_value", + client.delete_saved_query( + name="name_value", ) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - arg = args[0].parent - mock_val = "parent_value" + arg = args[0].name + mock_val = "name_value" assert arg == mock_val -def test_list_annotations_flattened_error(): +def test_delete_saved_query_flattened_error(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -7370,43 +8673,45 @@ def test_list_annotations_flattened_error(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.list_annotations( - dataset_service.ListAnnotationsRequest(), - parent="parent_value", + client.delete_saved_query( + dataset_service.DeleteSavedQueryRequest(), + name="name_value", ) @pytest.mark.asyncio -async def test_list_annotations_flattened_async(): +async def test_delete_saved_query_flattened_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_saved_query), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListAnnotationsResponse() + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListAnnotationsResponse() + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_annotations( - parent="parent_value", + response = await client.delete_saved_query( + name="name_value", ) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - arg = args[0].parent - mock_val = "parent_value" + arg = args[0].name + mock_val = "name_value" assert arg == mock_val @pytest.mark.asyncio -async def test_list_annotations_flattened_error_async(): +async def test_delete_saved_query_flattened_error_async(): client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -7414,346 +8719,1488 @@ async def test_list_annotations_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.list_annotations( - dataset_service.ListAnnotationsRequest(), - parent="parent_value", + await client.delete_saved_query( + dataset_service.DeleteSavedQueryRequest(), + name="name_value", ) -def test_list_annotations_pager(transport_name: str = "grpc"): +@pytest.mark.parametrize( + "request_type", + [ + dataset_service.GetAnnotationSpecRequest, + dict, + ], +) +def test_get_annotation_spec(request_type, transport: str = "grpc"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport=transport_name, + transport=transport, ) - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: - # Set the response to a series of pages. 
- call.side_effect = ( - dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - annotation.Annotation(), - ], - next_page_token="abc", - ), - dataset_service.ListAnnotationsResponse( - annotations=[], - next_page_token="def", - ), - dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - ], - next_page_token="ghi", - ), - dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - ], - ), - RuntimeError, - ) + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() - metadata = () - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = annotation_spec.AnnotationSpec( + name="name_value", + display_name="display_name_value", + etag="etag_value", ) - pager = client.list_annotations(request={}) + response = client.get_annotation_spec(request) - assert pager._metadata == metadata + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = dataset_service.GetAnnotationSpecRequest() + assert args[0] == request - results = list(pager) - assert len(results) == 6 - assert all(isinstance(i, annotation.Annotation) for i in results) + # Establish that the response is the type that we expect. 
+ assert isinstance(response, annotation_spec.AnnotationSpec) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.etag == "etag_value" -def test_list_annotations_pages(transport_name: str = "grpc"): +def test_get_annotation_spec_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport=transport_name, + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - annotation.Annotation(), - ], - next_page_token="abc", - ), - dataset_service.ListAnnotationsResponse( - annotations=[], - next_page_token="def", - ), - dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - ], - next_page_token="ghi", - ), - dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - ], - ), - RuntimeError, + with mock.patch.object( + type(client.transport.get_annotation_spec), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
) - pages = list(client.list_annotations(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token + client.get_annotation_spec() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == dataset_service.GetAnnotationSpecRequest() -@pytest.mark.asyncio -async def test_list_annotations_async_pager(): - client = DatasetServiceAsyncClient( +def test_get_annotation_spec_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = dataset_service.GetAnnotationSpecRequest( + name="name_value", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock + type(client.transport.get_annotation_spec), "__call__" ) as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - annotation.Annotation(), - ], - next_page_token="abc", - ), - dataset_service.ListAnnotationsResponse( - annotations=[], - next_page_token="def", - ), - dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - ], - next_page_token="ghi", - ), - dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - ], - ), - RuntimeError, + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
) - async_pager = await client.list_annotations( - request={}, + client.get_annotation_spec(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == dataset_service.GetAnnotationSpecRequest( + name="name_value", ) - assert async_pager.next_page_token == "abc" - responses = [] - async for response in async_pager: # pragma: no branch - responses.append(response) - assert len(responses) == 6 - assert all(isinstance(i, annotation.Annotation) for i in responses) + +def test_get_annotation_spec_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_annotation_spec in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_annotation_spec + ] = mock_rpc + request = {} + client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_annotation_spec(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 @pytest.mark.asyncio -async def test_list_annotations_async_pages(): +async def test_get_annotation_spec_empty_call_async(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. 
client = DatasetServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock + type(client.transport.get_annotation_spec), "__call__" ) as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - annotation.Annotation(), - ], - next_page_token="abc", - ), - dataset_service.ListAnnotationsResponse( - annotations=[], - next_page_token="def", - ), - dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - ], - next_page_token="ghi", - ), - dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - ], - ), - RuntimeError, + # Designate an appropriate return value for the call. 
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec( + name="name_value", + display_name="display_name_value", + etag="etag_value", + ) ) - pages = [] - # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` - # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 - async for page_ in ( # pragma: no branch - await client.list_annotations(request={}) - ).pages: - pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + response = await client.get_annotation_spec() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == dataset_service.GetAnnotationSpecRequest() + + +@pytest.mark.asyncio +async def test_get_annotation_spec_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_annotation_spec + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_annotation_spec + ] = mock_object + + request = {} + await client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_annotation_spec(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_annotation_spec_async( + transport: str = "grpc_asyncio", + request_type=dataset_service.GetAnnotationSpecRequest, +): + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec( + name="name_value", + display_name="display_name_value", + etag="etag_value", + ) + ) + response = await client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = dataset_service.GetAnnotationSpecRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, annotation_spec.AnnotationSpec) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.etag == "etag_value" + + +@pytest.mark.asyncio +async def test_get_annotation_spec_async_from_dict(): + await test_get_annotation_spec_async(request_type=dict) + + +def test_get_annotation_spec_field_headers(): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. 
Set these to a non-empty value. + request = dataset_service.GetAnnotationSpecRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), "__call__" + ) as call: + call.return_value = annotation_spec.AnnotationSpec() + client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_annotation_spec_field_headers_async(): + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.GetAnnotationSpecRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec() + ) + await client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_get_annotation_spec_flattened(): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.get_annotation_spec), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = annotation_spec.AnnotationSpec() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_annotation_spec( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_get_annotation_spec_flattened_error(): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_annotation_spec( + dataset_service.GetAnnotationSpecRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_annotation_spec_flattened_async(): + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = annotation_spec.AnnotationSpec() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_annotation_spec( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_get_annotation_spec_flattened_error_async(): + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_annotation_spec( + dataset_service.GetAnnotationSpecRequest(), + name="name_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + dataset_service.ListAnnotationsRequest, + dict, + ], +) +def test_list_annotations(request_type, transport: str = "grpc"): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListAnnotationsResponse( + next_page_token="next_page_token_value", + ) + response = client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = dataset_service.ListAnnotationsRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListAnnotationsPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_annotations_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. 
+ client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.list_annotations() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == dataset_service.ListAnnotationsRequest() + + +def test_list_annotations_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = dataset_service.ListAnnotationsRequest( + parent="parent_value", + filter="filter_value", + page_token="page_token_value", + order_by="order_by_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client.list_annotations(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == dataset_service.ListAnnotationsRequest( + parent="parent_value", + filter="filter_value", + page_token="page_token_value", + order_by="order_by_value", + ) + + +def test_list_annotations_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_annotations in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_annotations + ] = mock_rpc + request = {} + client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_annotations(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_annotations_empty_call_async(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListAnnotationsResponse( + next_page_token="next_page_token_value", + ) + ) + response = await client.list_annotations() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == dataset_service.ListAnnotationsRequest() + + +@pytest.mark.asyncio +async def test_list_annotations_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_annotations + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_annotations + ] = mock_object + + request = {} + await client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_annotations(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_annotations_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ListAnnotationsRequest +): + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListAnnotationsResponse( + next_page_token="next_page_token_value", + ) + ) + response = await client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = dataset_service.ListAnnotationsRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListAnnotationsAsyncPager) + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_annotations_async_from_dict(): + await test_list_annotations_async(request_type=dict) + + +def test_list_annotations_field_headers(): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
+ request = dataset_service.ListAnnotationsRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + call.return_value = dataset_service.ListAnnotationsResponse() + client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_annotations_field_headers_async(): + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListAnnotationsRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListAnnotationsResponse() + ) + await client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +def test_list_annotations_flattened(): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListAnnotationsResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_annotations( + parent="parent_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + + +def test_list_annotations_flattened_error(): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_annotations( + dataset_service.ListAnnotationsRequest(), + parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_annotations_flattened_async(): + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListAnnotationsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListAnnotationsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_annotations( + parent="parent_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_list_annotations_flattened_error_async(): + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_annotations( + dataset_service.ListAnnotationsRequest(), + parent="parent_value", + ) + + +def test_list_annotations_pager(transport_name: str = "grpc"): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + dataset_service.ListAnnotationsResponse( + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + annotation.Annotation(), + ], + next_page_token="abc", + ), + dataset_service.ListAnnotationsResponse( + annotations=[], + next_page_token="def", + ), + dataset_service.ListAnnotationsResponse( + annotations=[ + annotation.Annotation(), + ], + next_page_token="ghi", + ), + dataset_service.ListAnnotationsResponse( + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_annotations(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, annotation.Annotation) for i in results) + + +def test_list_annotations_pages(transport_name: str = "grpc"): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + 
transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + dataset_service.ListAnnotationsResponse( + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + annotation.Annotation(), + ], + next_page_token="abc", + ), + dataset_service.ListAnnotationsResponse( + annotations=[], + next_page_token="def", + ), + dataset_service.ListAnnotationsResponse( + annotations=[ + annotation.Annotation(), + ], + next_page_token="ghi", + ), + dataset_service.ListAnnotationsResponse( + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + ], + ), + RuntimeError, + ) + pages = list(client.list_annotations(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_annotations_async_pager(): + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + dataset_service.ListAnnotationsResponse( + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + annotation.Annotation(), + ], + next_page_token="abc", + ), + dataset_service.ListAnnotationsResponse( + annotations=[], + next_page_token="def", + ), + dataset_service.ListAnnotationsResponse( + annotations=[ + annotation.Annotation(), + ], + next_page_token="ghi", + ), + dataset_service.ListAnnotationsResponse( + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_annotations( + request={}, + ) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: # pragma: no branch + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, annotation.Annotation) for i in responses) + + +@pytest.mark.asyncio +async def test_list_annotations_async_pages(): + client = DatasetServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + dataset_service.ListAnnotationsResponse( + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + annotation.Annotation(), + ], + next_page_token="abc", + ), + dataset_service.ListAnnotationsResponse( + annotations=[], + next_page_token="def", + ), + dataset_service.ListAnnotationsResponse( + annotations=[ + annotation.Annotation(), + ], + next_page_token="ghi", + ), + dataset_service.ListAnnotationsResponse( + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + ], + ), + RuntimeError, + ) + pages = [] + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch + await client.list_annotations(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -@pytest.mark.parametrize( - "request_type", - [ - dataset_service.CreateDatasetRequest, - dict, - ], -) -def test_create_dataset_rest(request_type): +@pytest.mark.parametrize( + "request_type", + [ + dataset_service.CreateDatasetRequest, + dict, + ], +) +def test_create_dataset_rest(request_type): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request_init["dataset"] = { + "name": "name_value", + "display_name": "display_name_value", + "description": "description_value", + "metadata_schema_uri": "metadata_schema_uri_value", + "metadata": { + "null_value": 0, + "number_value": 0.1285, + "string_value": "string_value_value", + "bool_value": True, + "struct_value": {"fields": {}}, + "list_value": {"values": {}}, + }, + "data_item_count": 1584, + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "etag": 
"etag_value", + "labels": {}, + "saved_queries": [ + { + "name": "name_value", + "display_name": "display_name_value", + "metadata": {}, + "create_time": {}, + "update_time": {}, + "annotation_filter": "annotation_filter_value", + "problem_type": "problem_type_value", + "annotation_spec_count": 2253, + "etag": "etag_value", + "support_automl_training": True, + } + ], + "encryption_spec": {"kms_key_name": "kms_key_name_value"}, + "metadata_artifact": "metadata_artifact_value", + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = dataset_service.CreateDatasetRequest.meta.fields["dataset"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. 
+ message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["dataset"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["dataset"][field])): + del request_init["dataset"][field][i][subfield] + else: + del 
request_init["dataset"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.create_dataset(request) + + # Establish that the response is the type that we expect. + assert response.operation.name == "operations/spam" + + +def test_create_dataset_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_dataset in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_dataset] = mock_rpc + + request = {} + client.create_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_create_dataset_rest_required_fields( + request_type=dataset_service.CreateDatasetRequest, +): + transport_class = transports.DatasetServiceRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_dataset._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_dataset._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.create_dataset(request) + + expected_params = [] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_create_dataset_rest_unset_required_fields(): + transport = transports.DatasetServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.create_dataset._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "parent", + "dataset", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_create_dataset_rest_interceptors(null_interceptor): + transport = transports.DatasetServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatasetServiceRestInterceptor(), + ) + client = DatasetServiceClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + 
transports.DatasetServiceRestInterceptor, "post_create_dataset" + ) as post, mock.patch.object( + transports.DatasetServiceRestInterceptor, "pre_create_dataset" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = dataset_service.CreateDatasetRequest.pb( + dataset_service.CreateDatasetRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() + ) + + request = dataset_service.CreateDatasetRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + + client.create_dataset( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_create_dataset_rest_bad_request( + transport: str = "rest", request_type=dataset_service.CreateDatasetRequest +): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport="rest", + transport=transport, ) # send a request that will satisfy transcoding request_init = {"parent": "projects/sample1/locations/sample2"} - request_init["dataset"] = { - "name": "name_value", - "display_name": "display_name_value", - "description": "description_value", - "metadata_schema_uri": "metadata_schema_uri_value", - "metadata": { - "null_value": 0, - "number_value": 0.1285, - "string_value": "string_value_value", - "bool_value": True, - "struct_value": {"fields": {}}, - "list_value": {"values": {}}, - }, - "data_item_count": 1584, - "create_time": {"seconds": 751, "nanos": 543}, - "update_time": {}, - "etag": "etag_value", - "labels": {}, - "saved_queries": [ - { - "name": "name_value", - "display_name": "display_name_value", - 
"metadata": {}, - "create_time": {}, - "update_time": {}, - "annotation_filter": "annotation_filter_value", - "problem_type": "problem_type_value", - "annotation_spec_count": 2253, - "etag": "etag_value", - "support_automl_training": True, - } - ], - "encryption_spec": {"kms_key_name": "kms_key_name_value"}, - "metadata_artifact": "metadata_artifact_value", - } - # The version of a generated dependency at test runtime may differ from the version used during generation. - # Delete any fields which are not present in the current runtime dependency - # See https://github.com/googleapis/gapic-generator-python/issues/1748 + request = request_type(**request_init) - # Determine if the message type is proto-plus or protobuf - test_field = dataset_service.CreateDatasetRequest.meta.fields["dataset"] + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.create_dataset(request) - def get_message_fields(field): - # Given a field which is a message (composite type), return a list with - # all the fields of the message. - # If the field is not a composite type, return an empty list. 
- message_fields = [] - if hasattr(field, "message") and field.message: - is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") +def test_create_dataset_rest_flattened(): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) - if is_field_type_proto_plus_type: - message_fields = field.message.meta.fields.values() - # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types - else: # pragma: NO COVER - message_fields = field.message.DESCRIPTOR.fields - return message_fields + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") - runtime_nested_fields = [ - (field.name, nested_field.name) - for field in get_message_fields(test_field) - for nested_field in get_message_fields(field) - ] + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/locations/sample2"} - subfields_not_in_runtime = [] + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + dataset=gca_dataset.Dataset(name="name_value"), + ) + mock_args.update(sample_request) - # For each item in the sample request, create a list of sub fields which are not present at runtime - # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime - for field, value in request_init["dataset"].items(): # pragma: NO COVER - result = None - is_repeated = False - # For repeated fields - if isinstance(value, list) and len(value): - is_repeated = True - result = value[0] - # For fields where the type is another message - if isinstance(value, dict): - result = value + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + 
json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value - if result and hasattr(result, "keys"): - for subfield in result.keys(): - if (field, subfield) not in runtime_nested_fields: - subfields_not_in_runtime.append( - { - "field": field, - "subfield": subfield, - "is_repeated": is_repeated, - } - ) + client.create_dataset(**mock_args) - # Remove fields from the sample request which are not present in the runtime version of the dependency - # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime - for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER - field = subfield_to_delete.get("field") - field_repeated = subfield_to_delete.get("is_repeated") - subfield = subfield_to_delete.get("subfield") - if subfield: - if field_repeated: - for i in range(0, len(request_init["dataset"][field])): - del request_init["dataset"][field][i][subfield] - else: - del request_init["dataset"][field][subfield] + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1beta1/{parent=projects/*/locations/*}/datasets" + % client.transport._host, + args[1], + ) + + +def test_create_dataset_rest_flattened_error(transport: str = "rest"): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.create_dataset( + dataset_service.CreateDatasetRequest(), + parent="parent_value", + dataset=gca_dataset.Dataset(name="name_value"), + ) + + +def test_create_dataset_rest_error(): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + dataset_service.GetDatasetRequest, + dict, + ], +) +def test_get_dataset_rest(request_type): + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"name": "projects/sample1/locations/sample2/datasets/sample3"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + return_value = dataset.Dataset( + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + data_item_count=1584, + etag="etag_value", + metadata_artifact="metadata_artifact_value", + ) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 + # Convert return value to protobuf type + return_value = dataset.Dataset.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.create_dataset(request) + response = client.get_dataset(request) # Establish that the response is the type that we expect. 
- assert response.operation.name == "operations/spam" + assert isinstance(response, dataset.Dataset) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" + assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.data_item_count == 1584 + assert response.etag == "etag_value" + assert response.metadata_artifact == "metadata_artifact_value" -def test_create_dataset_rest_required_fields( - request_type=dataset_service.CreateDatasetRequest, +def test_get_dataset_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_dataset in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_dataset] = mock_rpc + + request = {} + client.get_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_get_dataset_rest_required_fields( + request_type=dataset_service.GetDatasetRequest, ): transport_class = transports.DatasetServiceRestTransport request_init = {} - request_init["parent"] = "" + request_init["name"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -7764,21 +10211,23 @@ def test_create_dataset_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).create_dataset._get_unset_required_fields(jsonified_request) + ).get_dataset._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["parent"] = "parent_value" + jsonified_request["name"] = "name_value" unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).create_dataset._get_unset_required_fields(jsonified_request) + ).get_dataset._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("read_mask",)) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "parent" in jsonified_request - assert jsonified_request["parent"] == "parent_value" + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -7787,7 +10236,7 @@ def test_create_dataset_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. 
- return_value = operations_pb2.Operation(name="operations/spam") + return_value = dataset.Dataset() # Mock the http request call within the method and fake a response. with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -7799,45 +10248,39 @@ def test_create_dataset_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "post", + "method": "get", "query_params": pb_request, } - transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = dataset.Dataset.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.create_dataset(request) + response = client.get_dataset(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_create_dataset_rest_unset_required_fields(): +def test_get_dataset_rest_unset_required_fields(): transport = transports.DatasetServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.create_dataset._get_unset_required_fields({}) - assert set(unset_fields) == ( - set(()) - & set( - ( - "parent", - "dataset", - ) - ) - ) + unset_fields = transport.get_dataset._get_unset_required_fields({}) + assert set(unset_fields) == (set(("readMask",)) & set(("name",))) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_create_dataset_rest_interceptors(null_interceptor): +def test_get_dataset_rest_interceptors(null_interceptor): transport = transports.DatasetServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -7850,16 +10293,14 @@ def test_create_dataset_rest_interceptors(null_interceptor): ) as 
req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - operation.Operation, "_set_result_from_operation" - ), mock.patch.object( - transports.DatasetServiceRestInterceptor, "post_create_dataset" + transports.DatasetServiceRestInterceptor, "post_get_dataset" ) as post, mock.patch.object( - transports.DatasetServiceRestInterceptor, "pre_create_dataset" + transports.DatasetServiceRestInterceptor, "pre_get_dataset" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = dataset_service.CreateDatasetRequest.pb( - dataset_service.CreateDatasetRequest() + pb_message = dataset_service.GetDatasetRequest.pb( + dataset_service.GetDatasetRequest() ) transcode.return_value = { "method": "post", @@ -7871,19 +10312,17 @@ def test_create_dataset_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = json_format.MessageToJson( - operations_pb2.Operation() - ) + req.return_value._content = dataset.Dataset.to_json(dataset.Dataset()) - request = dataset_service.CreateDatasetRequest() + request = dataset_service.GetDatasetRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = operations_pb2.Operation() + post.return_value = dataset.Dataset() - client.create_dataset( + client.get_dataset( request, metadata=[ ("key", "val"), @@ -7895,8 +10334,8 @@ def test_create_dataset_rest_interceptors(null_interceptor): post.assert_called_once() -def test_create_dataset_rest_bad_request( - transport: str = "rest", request_type=dataset_service.CreateDatasetRequest +def test_get_dataset_rest_bad_request( + transport: str = "rest", request_type=dataset_service.GetDatasetRequest ): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -7904,7 +10343,7 @@ def test_create_dataset_rest_bad_request( ) # send a request that will satisfy 
transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = {"name": "projects/sample1/locations/sample2/datasets/sample3"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -7916,10 +10355,10 @@ def test_create_dataset_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.create_dataset(request) + client.get_dataset(request) -def test_create_dataset_rest_flattened(): +def test_get_dataset_rest_flattened(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -7928,39 +10367,40 @@ def test_create_dataset_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + return_value = dataset.Dataset() # get arguments that satisfy an http rule for this method - sample_request = {"parent": "projects/sample1/locations/sample2"} + sample_request = {"name": "projects/sample1/locations/sample2/datasets/sample3"} # get truthy value for each flattened field mock_args = dict( - parent="parent_value", - dataset=gca_dataset.Dataset(name="name_value"), + name="name_value", ) mock_args.update(sample_request) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 + # Convert return value to protobuf type + return_value = dataset.Dataset.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - client.create_dataset(**mock_args) + client.get_dataset(**mock_args) # Establish that the underlying call was made with the expected # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{parent=projects/*/locations/*}/datasets" + "%s/v1beta1/{name=projects/*/locations/*/datasets/*}" % client.transport._host, args[1], ) -def test_create_dataset_rest_flattened_error(transport: str = "rest"): +def test_get_dataset_rest_flattened_error(transport: str = "rest"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -7969,14 +10409,13 @@ def test_create_dataset_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.create_dataset( - dataset_service.CreateDatasetRequest(), - parent="parent_value", - dataset=gca_dataset.Dataset(name="name_value"), + client.get_dataset( + dataset_service.GetDatasetRequest(), + name="name_value", ) -def test_create_dataset_rest_error(): +def test_get_dataset_rest_error(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -7985,24 +10424,128 @@ def test_create_dataset_rest_error(): @pytest.mark.parametrize( "request_type", [ - dataset_service.GetDatasetRequest, + dataset_service.UpdateDatasetRequest, dict, ], ) -def test_get_dataset_rest(request_type): +def test_update_dataset_rest(request_type): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", ) # send a request that will satisfy transcoding - request_init = {"name": "projects/sample1/locations/sample2/datasets/sample3"} + request_init = { + "dataset": {"name": "projects/sample1/locations/sample2/datasets/sample3"} + } + request_init["dataset"] = { + "name": "projects/sample1/locations/sample2/datasets/sample3", + "display_name": "display_name_value", + "description": "description_value", + "metadata_schema_uri": "metadata_schema_uri_value", + "metadata": { + "null_value": 0, + 
"number_value": 0.1285, + "string_value": "string_value_value", + "bool_value": True, + "struct_value": {"fields": {}}, + "list_value": {"values": {}}, + }, + "data_item_count": 1584, + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "etag": "etag_value", + "labels": {}, + "saved_queries": [ + { + "name": "name_value", + "display_name": "display_name_value", + "metadata": {}, + "create_time": {}, + "update_time": {}, + "annotation_filter": "annotation_filter_value", + "problem_type": "problem_type_value", + "annotation_spec_count": 2253, + "etag": "etag_value", + "support_automl_training": True, + } + ], + "encryption_spec": {"kms_key_name": "kms_key_name_value"}, + "metadata_artifact": "metadata_artifact_value", + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = dataset_service.UpdateDatasetRequest.meta.fields["dataset"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. 
+ message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["dataset"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["dataset"][field])): + del request_init["dataset"][field][i][subfield] + else: + del 
request_init["dataset"][field][subfield] request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = dataset.Dataset( + return_value = gca_dataset.Dataset( name="name_value", display_name="display_name_value", description="description_value", @@ -8016,15 +10559,15 @@ def test_get_dataset_rest(request_type): response_value = Response() response_value.status_code = 200 # Convert return value to protobuf type - return_value = dataset.Dataset.pb(return_value) + return_value = gca_dataset.Dataset.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.get_dataset(request) + response = client.update_dataset(request) # Establish that the response is the type that we expect. 
- assert isinstance(response, dataset.Dataset) + assert isinstance(response, gca_dataset.Dataset) assert response.name == "name_value" assert response.display_name == "display_name_value" assert response.description == "description_value" @@ -8034,13 +10577,48 @@ def test_get_dataset_rest(request_type): assert response.metadata_artifact == "metadata_artifact_value" -def test_get_dataset_rest_required_fields( - request_type=dataset_service.GetDatasetRequest, +def test_update_dataset_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_dataset in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_dataset] = mock_rpc + + request = {} + client.update_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_update_dataset_rest_required_fields( + request_type=dataset_service.UpdateDatasetRequest, ): transport_class = transports.DatasetServiceRestTransport request_init = {} - request_init["name"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -8051,23 +10629,19 @@ def test_get_dataset_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).get_dataset._get_unset_required_fields(jsonified_request) + ).update_dataset._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["name"] = "name_value" - unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).get_dataset._get_unset_required_fields(jsonified_request) + ).update_dataset._get_unset_required_fields(jsonified_request) # Check that path parameters and body parameters are not mixing in. - assert not set(unset_fields) - set(("read_mask",)) + assert not set(unset_fields) - set(("update_mask",)) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "name" in jsonified_request - assert jsonified_request["name"] == "name_value" client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -8076,7 +10650,7 @@ def test_get_dataset_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = dataset.Dataset() + return_value = gca_dataset.Dataset() # Mock the http request call within the method and fake a response. 
with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -8088,39 +10662,48 @@ def test_get_dataset_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "get", + "method": "patch", "query_params": pb_request, } + transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 # Convert return value to protobuf type - return_value = dataset.Dataset.pb(return_value) + return_value = gca_dataset.Dataset.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.get_dataset(request) + response = client.update_dataset(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_get_dataset_rest_unset_required_fields(): +def test_update_dataset_rest_unset_required_fields(): transport = transports.DatasetServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.get_dataset._get_unset_required_fields({}) - assert set(unset_fields) == (set(("readMask",)) & set(("name",))) + unset_fields = transport.update_dataset._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(("updateMask",)) + & set( + ( + "dataset", + "updateMask", + ) + ) + ) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_get_dataset_rest_interceptors(null_interceptor): +def test_update_dataset_rest_interceptors(null_interceptor): transport = transports.DatasetServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -8133,14 +10716,14 @@ def test_get_dataset_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - 
transports.DatasetServiceRestInterceptor, "post_get_dataset" + transports.DatasetServiceRestInterceptor, "post_update_dataset" ) as post, mock.patch.object( - transports.DatasetServiceRestInterceptor, "pre_get_dataset" + transports.DatasetServiceRestInterceptor, "pre_update_dataset" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = dataset_service.GetDatasetRequest.pb( - dataset_service.GetDatasetRequest() + pb_message = dataset_service.UpdateDatasetRequest.pb( + dataset_service.UpdateDatasetRequest() ) transcode.return_value = { "method": "post", @@ -8152,17 +10735,17 @@ def test_get_dataset_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = dataset.Dataset.to_json(dataset.Dataset()) + req.return_value._content = gca_dataset.Dataset.to_json(gca_dataset.Dataset()) - request = dataset_service.GetDatasetRequest() + request = dataset_service.UpdateDatasetRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = dataset.Dataset() + post.return_value = gca_dataset.Dataset() - client.get_dataset( + client.update_dataset( request, metadata=[ ("key", "val"), @@ -8174,8 +10757,8 @@ def test_get_dataset_rest_interceptors(null_interceptor): post.assert_called_once() -def test_get_dataset_rest_bad_request( - transport: str = "rest", request_type=dataset_service.GetDatasetRequest +def test_update_dataset_rest_bad_request( + transport: str = "rest", request_type=dataset_service.UpdateDatasetRequest ): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -8183,7 +10766,9 @@ def test_get_dataset_rest_bad_request( ) # send a request that will satisfy transcoding - request_init = {"name": "projects/sample1/locations/sample2/datasets/sample3"} + request_init = { + "dataset": {"name": "projects/sample1/locations/sample2/datasets/sample3"} + 
} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -8195,10 +10780,10 @@ def test_get_dataset_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.get_dataset(request) + client.update_dataset(request) -def test_get_dataset_rest_flattened(): +def test_update_dataset_rest_flattened(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -8207,14 +10792,17 @@ def test_get_dataset_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = dataset.Dataset() + return_value = gca_dataset.Dataset() # get arguments that satisfy an http rule for this method - sample_request = {"name": "projects/sample1/locations/sample2/datasets/sample3"} + sample_request = { + "dataset": {"name": "projects/sample1/locations/sample2/datasets/sample3"} + } # get truthy value for each flattened field mock_args = dict( - name="name_value", + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) mock_args.update(sample_request) @@ -8222,25 +10810,25 @@ def test_get_dataset_rest_flattened(): response_value = Response() response_value.status_code = 200 # Convert return value to protobuf type - return_value = dataset.Dataset.pb(return_value) + return_value = gca_dataset.Dataset.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - client.get_dataset(**mock_args) + client.update_dataset(**mock_args) # Establish that the underlying call was made with the expected # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{name=projects/*/locations/*/datasets/*}" + "%s/v1beta1/{dataset.name=projects/*/locations/*/datasets/*}" % client.transport._host, args[1], ) -def test_get_dataset_rest_flattened_error(transport: str = "rest"): +def test_update_dataset_rest_flattened_error(transport: str = "rest"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -8249,13 +10837,14 @@ def test_get_dataset_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.get_dataset( - dataset_service.GetDatasetRequest(), - name="name_value", + client.update_dataset( + dataset_service.UpdateDatasetRequest(), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) -def test_get_dataset_rest_error(): +def test_update_dataset_rest_error(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -8264,165 +10853,86 @@ def test_get_dataset_rest_error(): @pytest.mark.parametrize( "request_type", [ - dataset_service.UpdateDatasetRequest, + dataset_service.ListDatasetsRequest, dict, ], ) -def test_update_dataset_rest(request_type): +def test_list_datasets_rest(request_type): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", ) # send a request that will satisfy transcoding - request_init = { - "dataset": {"name": "projects/sample1/locations/sample2/datasets/sample3"} - } - request_init["dataset"] = { - "name": "projects/sample1/locations/sample2/datasets/sample3", - "display_name": "display_name_value", - "description": "description_value", - "metadata_schema_uri": "metadata_schema_uri_value", - "metadata": { - "null_value": 0, - "number_value": 0.1285, - "string_value": 
"string_value_value", - "bool_value": True, - "struct_value": {"fields": {}}, - "list_value": {"values": {}}, - }, - "data_item_count": 1584, - "create_time": {"seconds": 751, "nanos": 543}, - "update_time": {}, - "etag": "etag_value", - "labels": {}, - "saved_queries": [ - { - "name": "name_value", - "display_name": "display_name_value", - "metadata": {}, - "create_time": {}, - "update_time": {}, - "annotation_filter": "annotation_filter_value", - "problem_type": "problem_type_value", - "annotation_spec_count": 2253, - "etag": "etag_value", - "support_automl_training": True, - } - ], - "encryption_spec": {"kms_key_name": "kms_key_name_value"}, - "metadata_artifact": "metadata_artifact_value", - } - # The version of a generated dependency at test runtime may differ from the version used during generation. - # Delete any fields which are not present in the current runtime dependency - # See https://github.com/googleapis/gapic-generator-python/issues/1748 - - # Determine if the message type is proto-plus or protobuf - test_field = dataset_service.UpdateDatasetRequest.meta.fields["dataset"] - - def get_message_fields(field): - # Given a field which is a message (composite type), return a list with - # all the fields of the message. - # If the field is not a composite type, return an empty list. 
- message_fields = [] - - if hasattr(field, "message") and field.message: - is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") - - if is_field_type_proto_plus_type: - message_fields = field.message.meta.fields.values() - # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types - else: # pragma: NO COVER - message_fields = field.message.DESCRIPTOR.fields - return message_fields - - runtime_nested_fields = [ - (field.name, nested_field.name) - for field in get_message_fields(test_field) - for nested_field in get_message_fields(field) - ] - - subfields_not_in_runtime = [] - - # For each item in the sample request, create a list of sub fields which are not present at runtime - # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime - for field, value in request_init["dataset"].items(): # pragma: NO COVER - result = None - is_repeated = False - # For repeated fields - if isinstance(value, list) and len(value): - is_repeated = True - result = value[0] - # For fields where the type is another message - if isinstance(value, dict): - result = value - - if result and hasattr(result, "keys"): - for subfield in result.keys(): - if (field, subfield) not in runtime_nested_fields: - subfields_not_in_runtime.append( - { - "field": field, - "subfield": subfield, - "is_repeated": is_repeated, - } - ) - - # Remove fields from the sample request which are not present in the runtime version of the dependency - # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime - for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER - field = subfield_to_delete.get("field") - field_repeated = subfield_to_delete.get("is_repeated") - subfield = subfield_to_delete.get("subfield") - if subfield: - if field_repeated: - for i in range(0, len(request_init["dataset"][field])): - del request_init["dataset"][field][i][subfield] - else: - del 
request_init["dataset"][field][subfield] + request_init = {"parent": "projects/sample1/locations/sample2"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = gca_dataset.Dataset( - name="name_value", - display_name="display_name_value", - description="description_value", - metadata_schema_uri="metadata_schema_uri_value", - data_item_count=1584, - etag="etag_value", - metadata_artifact="metadata_artifact_value", + return_value = dataset_service.ListDatasetsResponse( + next_page_token="next_page_token_value", ) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 # Convert return value to protobuf type - return_value = gca_dataset.Dataset.pb(return_value) + return_value = dataset_service.ListDatasetsResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.update_dataset(request) + response = client.list_datasets(request) # Establish that the response is the type that we expect. 
- assert isinstance(response, gca_dataset.Dataset) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.description == "description_value" - assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.data_item_count == 1584 - assert response.etag == "etag_value" - assert response.metadata_artifact == "metadata_artifact_value" + assert isinstance(response, pagers.ListDatasetsPager) + assert response.next_page_token == "next_page_token_value" -def test_update_dataset_rest_required_fields( - request_type=dataset_service.UpdateDatasetRequest, +def test_list_datasets_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_datasets in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_datasets] = mock_rpc + + request = {} + client.list_datasets(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_datasets(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_list_datasets_rest_required_fields( + request_type=dataset_service.ListDatasetsRequest, ): transport_class = transports.DatasetServiceRestTransport request_init = {} + request_init["parent"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -8433,19 +10943,31 @@ def test_update_dataset_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).update_dataset._get_unset_required_fields(jsonified_request) + ).list_datasets._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present + jsonified_request["parent"] = "parent_value" + unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).update_dataset._get_unset_required_fields(jsonified_request) + ).list_datasets._get_unset_required_fields(jsonified_request) # Check that path parameters and body parameters are not mixing in. - assert not set(unset_fields) - set(("update_mask",)) + assert not set(unset_fields) - set( + ( + "filter", + "order_by", + "page_size", + "page_token", + "read_mask", + ) + ) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -8454,7 +10976,7 @@ def test_update_dataset_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. 
- return_value = gca_dataset.Dataset() + return_value = dataset_service.ListDatasetsResponse() # Mock the http request call within the method and fake a response. with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -8466,48 +10988,50 @@ def test_update_dataset_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "patch", + "method": "get", "query_params": pb_request, } - transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 # Convert return value to protobuf type - return_value = gca_dataset.Dataset.pb(return_value) + return_value = dataset_service.ListDatasetsResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.update_dataset(request) + response = client.list_datasets(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_update_dataset_rest_unset_required_fields(): +def test_list_datasets_rest_unset_required_fields(): transport = transports.DatasetServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.update_dataset._get_unset_required_fields({}) + unset_fields = transport.list_datasets._get_unset_required_fields({}) assert set(unset_fields) == ( - set(("updateMask",)) - & set( + set( ( - "dataset", - "updateMask", + "filter", + "orderBy", + "pageSize", + "pageToken", + "readMask", ) ) + & set(("parent",)) ) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_update_dataset_rest_interceptors(null_interceptor): +def test_list_datasets_rest_interceptors(null_interceptor): transport = transports.DatasetServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), 
interceptor=None @@ -8520,14 +11044,14 @@ def test_update_dataset_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - transports.DatasetServiceRestInterceptor, "post_update_dataset" + transports.DatasetServiceRestInterceptor, "post_list_datasets" ) as post, mock.patch.object( - transports.DatasetServiceRestInterceptor, "pre_update_dataset" + transports.DatasetServiceRestInterceptor, "pre_list_datasets" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = dataset_service.UpdateDatasetRequest.pb( - dataset_service.UpdateDatasetRequest() + pb_message = dataset_service.ListDatasetsRequest.pb( + dataset_service.ListDatasetsRequest() ) transcode.return_value = { "method": "post", @@ -8539,17 +11063,19 @@ def test_update_dataset_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = gca_dataset.Dataset.to_json(gca_dataset.Dataset()) + req.return_value._content = dataset_service.ListDatasetsResponse.to_json( + dataset_service.ListDatasetsResponse() + ) - request = dataset_service.UpdateDatasetRequest() + request = dataset_service.ListDatasetsRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = gca_dataset.Dataset() + post.return_value = dataset_service.ListDatasetsResponse() - client.update_dataset( + client.list_datasets( request, metadata=[ ("key", "val"), @@ -8561,8 +11087,8 @@ def test_update_dataset_rest_interceptors(null_interceptor): post.assert_called_once() -def test_update_dataset_rest_bad_request( - transport: str = "rest", request_type=dataset_service.UpdateDatasetRequest +def test_list_datasets_rest_bad_request( + transport: str = "rest", request_type=dataset_service.ListDatasetsRequest ): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ 
-8570,9 +11096,7 @@ def test_update_dataset_rest_bad_request( ) # send a request that will satisfy transcoding - request_init = { - "dataset": {"name": "projects/sample1/locations/sample2/datasets/sample3"} - } + request_init = {"parent": "projects/sample1/locations/sample2"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -8584,10 +11108,10 @@ def test_update_dataset_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.update_dataset(request) + client.list_datasets(request) -def test_update_dataset_rest_flattened(): +def test_list_datasets_rest_flattened(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -8596,17 +11120,14 @@ def test_update_dataset_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
- return_value = gca_dataset.Dataset() + return_value = dataset_service.ListDatasetsResponse() # get arguments that satisfy an http rule for this method - sample_request = { - "dataset": {"name": "projects/sample1/locations/sample2/datasets/sample3"} - } + sample_request = {"parent": "projects/sample1/locations/sample2"} # get truthy value for each flattened field mock_args = dict( - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + parent="parent_value", ) mock_args.update(sample_request) @@ -8614,25 +11135,25 @@ def test_update_dataset_rest_flattened(): response_value = Response() response_value.status_code = 200 # Convert return value to protobuf type - return_value = gca_dataset.Dataset.pb(return_value) + return_value = dataset_service.ListDatasetsResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - client.update_dataset(**mock_args) + client.list_datasets(**mock_args) # Establish that the underlying call was made with the expected # request object values. assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{dataset.name=projects/*/locations/*/datasets/*}" + "%s/v1beta1/{parent=projects/*/locations/*}/datasets" % client.transport._host, args[1], ) -def test_update_dataset_rest_flattened_error(transport: str = "rest"): +def test_list_datasets_rest_flattened_error(transport: str = "rest"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -8641,66 +11162,157 @@ def test_update_dataset_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - client.update_dataset( - dataset_service.UpdateDatasetRequest(), - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + client.list_datasets( + dataset_service.ListDatasetsRequest(), + parent="parent_value", ) -def test_update_dataset_rest_error(): +def test_list_datasets_rest_pager(transport: str = "rest"): client = DatasetServiceClient( - credentials=ga_credentials.AnonymousCredentials(), transport="rest" + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, ) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + dataset_service.ListDatasetsResponse( + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + dataset.Dataset(), + ], + next_page_token="abc", + ), + dataset_service.ListDatasetsResponse( + datasets=[], + next_page_token="def", + ), + dataset_service.ListDatasetsResponse( + datasets=[ + dataset.Dataset(), + ], + next_page_token="ghi", + ), + dataset_service.ListDatasetsResponse( + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + dataset_service.ListDatasetsResponse.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {"parent": "projects/sample1/locations/sample2"} + + pager = client.list_datasets(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert 
all(isinstance(i, dataset.Dataset) for i in results) + + pages = list(client.list_datasets(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + @pytest.mark.parametrize( "request_type", [ - dataset_service.ListDatasetsRequest, + dataset_service.DeleteDatasetRequest, dict, ], ) -def test_list_datasets_rest(request_type): +def test_delete_dataset_rest(request_type): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = {"name": "projects/sample1/locations/sample2/datasets/sample3"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = dataset_service.ListDatasetsResponse( - next_page_token="next_page_token_value", - ) + return_value = operations_pb2.Operation(name="operations/spam") # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - # Convert return value to protobuf type - return_value = dataset_service.ListDatasetsResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.list_datasets(request) + response = client.delete_dataset(request) # Establish that the response is the type that we expect. 
- assert isinstance(response, pagers.ListDatasetsPager) - assert response.next_page_token == "next_page_token_value" + assert response.operation.name == "operations/spam" -def test_list_datasets_rest_required_fields( - request_type=dataset_service.ListDatasetsRequest, +def test_delete_dataset_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_dataset in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_dataset] = mock_rpc + + request = {} + client.delete_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_delete_dataset_rest_required_fields( + request_type=dataset_service.DeleteDatasetRequest, ): transport_class = transports.DatasetServiceRestTransport request_init = {} - request_init["parent"] = "" + request_init["name"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -8711,31 +11323,21 @@ def test_list_datasets_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).list_datasets._get_unset_required_fields(jsonified_request) + ).delete_dataset._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["parent"] = "parent_value" + jsonified_request["name"] = "name_value" unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).list_datasets._get_unset_required_fields(jsonified_request) - # Check that path parameters and body parameters are not mixing in. 
- assert not set(unset_fields) - set( - ( - "filter", - "order_by", - "page_size", - "page_token", - "read_mask", - ) - ) + ).delete_dataset._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "parent" in jsonified_request - assert jsonified_request["parent"] == "parent_value" + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -8744,7 +11346,7 @@ def test_list_datasets_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = dataset_service.ListDatasetsResponse() + return_value = operations_pb2.Operation(name="operations/spam") # Mock the http request call within the method and fake a response. with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -8756,50 +11358,36 @@ def test_list_datasets_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "get", + "method": "delete", "query_params": pb_request, } transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 - - # Convert return value to protobuf type - return_value = dataset_service.ListDatasetsResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.list_datasets(request) + response = client.delete_dataset(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_list_datasets_rest_unset_required_fields(): +def test_delete_dataset_rest_unset_required_fields(): transport = transports.DatasetServiceRestTransport( 
credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.list_datasets._get_unset_required_fields({}) - assert set(unset_fields) == ( - set( - ( - "filter", - "orderBy", - "pageSize", - "pageToken", - "readMask", - ) - ) - & set(("parent",)) - ) + unset_fields = transport.delete_dataset._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_list_datasets_rest_interceptors(null_interceptor): +def test_delete_dataset_rest_interceptors(null_interceptor): transport = transports.DatasetServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -8812,14 +11400,16 @@ def test_list_datasets_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - transports.DatasetServiceRestInterceptor, "post_list_datasets" + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.DatasetServiceRestInterceptor, "post_delete_dataset" ) as post, mock.patch.object( - transports.DatasetServiceRestInterceptor, "pre_list_datasets" + transports.DatasetServiceRestInterceptor, "pre_delete_dataset" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = dataset_service.ListDatasetsRequest.pb( - dataset_service.ListDatasetsRequest() + pb_message = dataset_service.DeleteDatasetRequest.pb( + dataset_service.DeleteDatasetRequest() ) transcode.return_value = { "method": "post", @@ -8831,19 +11421,19 @@ def test_list_datasets_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = dataset_service.ListDatasetsResponse.to_json( - dataset_service.ListDatasetsResponse() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() ) - request = dataset_service.ListDatasetsRequest() + request 
= dataset_service.DeleteDatasetRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = dataset_service.ListDatasetsResponse() + post.return_value = operations_pb2.Operation() - client.list_datasets( + client.delete_dataset( request, metadata=[ ("key", "val"), @@ -8855,8 +11445,8 @@ def test_list_datasets_rest_interceptors(null_interceptor): post.assert_called_once() -def test_list_datasets_rest_bad_request( - transport: str = "rest", request_type=dataset_service.ListDatasetsRequest +def test_delete_dataset_rest_bad_request( + transport: str = "rest", request_type=dataset_service.DeleteDatasetRequest ): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -8864,7 +11454,7 @@ def test_list_datasets_rest_bad_request( ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = {"name": "projects/sample1/locations/sample2/datasets/sample3"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -8876,10 +11466,10 @@ def test_list_datasets_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.list_datasets(request) + client.delete_dataset(request) -def test_list_datasets_rest_flattened(): +def test_delete_dataset_rest_flattened(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -8888,40 +11478,38 @@ def test_list_datasets_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
- return_value = dataset_service.ListDatasetsResponse() + return_value = operations_pb2.Operation(name="operations/spam") # get arguments that satisfy an http rule for this method - sample_request = {"parent": "projects/sample1/locations/sample2"} + sample_request = {"name": "projects/sample1/locations/sample2/datasets/sample3"} # get truthy value for each flattened field mock_args = dict( - parent="parent_value", + name="name_value", ) mock_args.update(sample_request) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - # Convert return value to protobuf type - return_value = dataset_service.ListDatasetsResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - client.list_datasets(**mock_args) + client.delete_dataset(**mock_args) # Establish that the underlying call was made with the expected # request object values. assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{parent=projects/*/locations/*}/datasets" + "%s/v1beta1/{name=projects/*/locations/*/datasets/*}" % client.transport._host, args[1], ) -def test_list_datasets_rest_flattened_error(transport: str = "rest"): +def test_delete_dataset_rest_flattened_error(transport: str = "rest"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -8930,83 +11518,26 @@ def test_list_datasets_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - client.list_datasets( - dataset_service.ListDatasetsRequest(), - parent="parent_value", + client.delete_dataset( + dataset_service.DeleteDatasetRequest(), + name="name_value", ) -def test_list_datasets_rest_pager(transport: str = "rest"): +def test_delete_dataset_rest_error(): client = DatasetServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) - # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, "request") as req: - # TODO(kbandes): remove this mock unless there's a good reason for it. - # with mock.patch.object(path_template, 'transcode') as transcode: - # Set the response as a series of pages - response = ( - dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - dataset.Dataset(), - ], - next_page_token="abc", - ), - dataset_service.ListDatasetsResponse( - datasets=[], - next_page_token="def", - ), - dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - ], - next_page_token="ghi", - ), - dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - ], - ), - ) - # Two responses for two calls - response = response + response - - # Wrap the values into proper Response objs - response = tuple( - dataset_service.ListDatasetsResponse.to_json(x) for x in response - ) - return_values = tuple(Response() for i in response) - for return_val, response_val in zip(return_values, response): - return_val._content = response_val.encode("UTF-8") - return_val.status_code = 200 - req.side_effect = return_values - - sample_request = {"parent": "projects/sample1/locations/sample2"} - - pager = client.list_datasets(request=sample_request) - - results = list(pager) - assert len(results) == 6 - assert all(isinstance(i, dataset.Dataset) for i in results) - - pages = 
list(client.list_datasets(request=sample_request).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - @pytest.mark.parametrize( "request_type", [ - dataset_service.DeleteDatasetRequest, + dataset_service.ImportDataRequest, dict, ], ) -def test_delete_dataset_rest(request_type): +def test_import_data_rest(request_type): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -9028,14 +11559,54 @@ def test_delete_dataset_rest(request_type): response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.delete_dataset(request) + response = client.import_data(request) # Establish that the response is the type that we expect. assert response.operation.name == "operations/spam" -def test_delete_dataset_rest_required_fields( - request_type=dataset_service.DeleteDatasetRequest, +def test_import_data_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.import_data in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.import_data] = mock_rpc + + request = {} + client.import_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.import_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_import_data_rest_required_fields( + request_type=dataset_service.ImportDataRequest, ): transport_class = transports.DatasetServiceRestTransport @@ -9051,7 +11622,7 @@ def test_delete_dataset_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).delete_dataset._get_unset_required_fields(jsonified_request) + ).import_data._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present @@ -9060,7 +11631,7 @@ def test_delete_dataset_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).delete_dataset._get_unset_required_fields(jsonified_request) + ).import_data._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone @@ -9086,9 +11657,10 @@ def test_delete_dataset_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "delete", + "method": "post", "query_params": pb_request, } + transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() @@ -9098,24 +11670,32 @@ def test_delete_dataset_rest_required_fields( response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.delete_dataset(request) + response = client.import_data(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def 
test_delete_dataset_rest_unset_required_fields(): +def test_import_data_rest_unset_required_fields(): transport = transports.DatasetServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.delete_dataset._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("name",))) + unset_fields = transport.import_data._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "name", + "importConfigs", + ) + ) + ) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_delete_dataset_rest_interceptors(null_interceptor): +def test_import_data_rest_interceptors(null_interceptor): transport = transports.DatasetServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -9130,14 +11710,14 @@ def test_delete_dataset_rest_interceptors(null_interceptor): ) as transcode, mock.patch.object( operation.Operation, "_set_result_from_operation" ), mock.patch.object( - transports.DatasetServiceRestInterceptor, "post_delete_dataset" + transports.DatasetServiceRestInterceptor, "post_import_data" ) as post, mock.patch.object( - transports.DatasetServiceRestInterceptor, "pre_delete_dataset" + transports.DatasetServiceRestInterceptor, "pre_import_data" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = dataset_service.DeleteDatasetRequest.pb( - dataset_service.DeleteDatasetRequest() + pb_message = dataset_service.ImportDataRequest.pb( + dataset_service.ImportDataRequest() ) transcode.return_value = { "method": "post", @@ -9153,7 +11733,7 @@ def test_delete_dataset_rest_interceptors(null_interceptor): operations_pb2.Operation() ) - request = dataset_service.DeleteDatasetRequest() + request = dataset_service.ImportDataRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), @@ -9161,7 +11741,7 @@ def test_delete_dataset_rest_interceptors(null_interceptor): pre.return_value = request, metadata post.return_value = 
operations_pb2.Operation() - client.delete_dataset( + client.import_data( request, metadata=[ ("key", "val"), @@ -9173,8 +11753,8 @@ def test_delete_dataset_rest_interceptors(null_interceptor): post.assert_called_once() -def test_delete_dataset_rest_bad_request( - transport: str = "rest", request_type=dataset_service.DeleteDatasetRequest +def test_import_data_rest_bad_request( + transport: str = "rest", request_type=dataset_service.ImportDataRequest ): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -9194,10 +11774,10 @@ def test_delete_dataset_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.delete_dataset(request) + client.import_data(request) -def test_delete_dataset_rest_flattened(): +def test_import_data_rest_flattened(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -9214,6 +11794,9 @@ def test_delete_dataset_rest_flattened(): # get truthy value for each flattened field mock_args = dict( name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], ) mock_args.update(sample_request) @@ -9224,20 +11807,20 @@ def test_delete_dataset_rest_flattened(): response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - client.delete_dataset(**mock_args) + client.import_data(**mock_args) # Establish that the underlying call was made with the expected # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{name=projects/*/locations/*/datasets/*}" + "%s/v1beta1/{name=projects/*/locations/*/datasets/*}:import" % client.transport._host, args[1], ) -def test_delete_dataset_rest_flattened_error(transport: str = "rest"): +def test_import_data_rest_flattened_error(transport: str = "rest"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -9246,13 +11829,16 @@ def test_delete_dataset_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.delete_dataset( - dataset_service.DeleteDatasetRequest(), + client.import_data( + dataset_service.ImportDataRequest(), name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], ) -def test_delete_dataset_rest_error(): +def test_import_data_rest_error(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -9261,11 +11847,11 @@ def test_delete_dataset_rest_error(): @pytest.mark.parametrize( "request_type", [ - dataset_service.ImportDataRequest, + dataset_service.ExportDataRequest, dict, ], ) -def test_import_data_rest(request_type): +def test_export_data_rest(request_type): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -9287,14 +11873,54 @@ def test_import_data_rest(request_type): response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.import_data(request) + response = client.export_data(request) # Establish that the response is the type that we expect. 
assert response.operation.name == "operations/spam" -def test_import_data_rest_required_fields( - request_type=dataset_service.ImportDataRequest, +def test_export_data_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.export_data in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.export_data] = mock_rpc + + request = {} + client.export_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.export_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_export_data_rest_required_fields( + request_type=dataset_service.ExportDataRequest, ): transport_class = transports.DatasetServiceRestTransport @@ -9310,7 +11936,7 @@ def test_import_data_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).import_data._get_unset_required_fields(jsonified_request) + ).export_data._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present @@ -9319,7 +11945,7 @@ def test_import_data_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).import_data._get_unset_required_fields(jsonified_request) + ).export_data._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone @@ -9358,32 +11984,32 @@ def test_import_data_rest_required_fields( response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.import_data(request) + response = client.export_data(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_import_data_rest_unset_required_fields(): +def test_export_data_rest_unset_required_fields(): transport = transports.DatasetServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.import_data._get_unset_required_fields({}) + unset_fields = transport.export_data._get_unset_required_fields({}) assert set(unset_fields) == ( set(()) & set( ( 
"name", - "importConfigs", + "exportConfig", ) ) ) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_import_data_rest_interceptors(null_interceptor): +def test_export_data_rest_interceptors(null_interceptor): transport = transports.DatasetServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -9398,14 +12024,14 @@ def test_import_data_rest_interceptors(null_interceptor): ) as transcode, mock.patch.object( operation.Operation, "_set_result_from_operation" ), mock.patch.object( - transports.DatasetServiceRestInterceptor, "post_import_data" + transports.DatasetServiceRestInterceptor, "post_export_data" ) as post, mock.patch.object( - transports.DatasetServiceRestInterceptor, "pre_import_data" + transports.DatasetServiceRestInterceptor, "pre_export_data" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = dataset_service.ImportDataRequest.pb( - dataset_service.ImportDataRequest() + pb_message = dataset_service.ExportDataRequest.pb( + dataset_service.ExportDataRequest() ) transcode.return_value = { "method": "post", @@ -9421,7 +12047,7 @@ def test_import_data_rest_interceptors(null_interceptor): operations_pb2.Operation() ) - request = dataset_service.ImportDataRequest() + request = dataset_service.ExportDataRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), @@ -9429,7 +12055,7 @@ def test_import_data_rest_interceptors(null_interceptor): pre.return_value = request, metadata post.return_value = operations_pb2.Operation() - client.import_data( + client.export_data( request, metadata=[ ("key", "val"), @@ -9441,8 +12067,8 @@ def test_import_data_rest_interceptors(null_interceptor): post.assert_called_once() -def test_import_data_rest_bad_request( - transport: str = "rest", request_type=dataset_service.ImportDataRequest +def test_export_data_rest_bad_request( + transport: str = "rest", request_type=dataset_service.ExportDataRequest ): client = DatasetServiceClient( 
credentials=ga_credentials.AnonymousCredentials(), @@ -9462,10 +12088,10 @@ def test_import_data_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.import_data(request) + client.export_data(request) -def test_import_data_rest_flattened(): +def test_export_data_rest_flattened(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -9482,9 +12108,11 @@ def test_import_data_rest_flattened(): # get truthy value for each flattened field mock_args = dict( name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), ) mock_args.update(sample_request) @@ -9495,20 +12123,20 @@ def test_import_data_rest_flattened(): response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - client.import_data(**mock_args) + client.export_data(**mock_args) # Establish that the underlying call was made with the expected # request object values. assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{name=projects/*/locations/*/datasets/*}:import" + "%s/v1beta1/{name=projects/*/locations/*/datasets/*}:export" % client.transport._host, args[1], ) -def test_import_data_rest_flattened_error(transport: str = "rest"): +def test_export_data_rest_flattened_error(transport: str = "rest"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -9517,16 +12145,18 @@ def test_import_data_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - client.import_data( - dataset_service.ImportDataRequest(), + client.export_data( + dataset_service.ExportDataRequest(), name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), ) -def test_import_data_rest_error(): +def test_export_data_rest_error(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -9535,18 +12165,103 @@ def test_import_data_rest_error(): @pytest.mark.parametrize( "request_type", [ - dataset_service.ExportDataRequest, + dataset_service.CreateDatasetVersionRequest, dict, ], ) -def test_export_data_rest(request_type): +def test_create_dataset_version_rest(request_type): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", ) # send a request that will satisfy transcoding - request_init = {"name": "projects/sample1/locations/sample2/datasets/sample3"} + request_init = {"parent": "projects/sample1/locations/sample2/datasets/sample3"} + request_init["dataset_version"] = { + "name": "name_value", + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "etag": "etag_value", + "big_query_dataset_name": "big_query_dataset_name_value", + "display_name": "display_name_value", + "metadata": { + "null_value": 0, + "number_value": 0.1285, + "string_value": "string_value_value", + "bool_value": True, + "struct_value": {"fields": {}}, + "list_value": {"values": {}}, + }, + } + # The version of a generated dependency at test runtime may differ from the version used during generation. 
+ # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = dataset_service.CreateDatasetVersionRequest.meta.fields[ + "dataset_version" + ] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["dataset_version"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime 
version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["dataset_version"][field])): + del request_init["dataset_version"][field][i][subfield] + else: + del request_init["dataset_version"][field][subfield] request = request_type(**request_init) # Mock the http request call within the method and fake a response. @@ -9561,19 +12276,64 @@ def test_export_data_rest(request_type): response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.export_data(request) + response = client.create_dataset_version(request) # Establish that the response is the type that we expect. assert response.operation.name == "operations/spam" -def test_export_data_rest_required_fields( - request_type=dataset_service.ExportDataRequest, +def test_create_dataset_version_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_dataset_version + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.create_dataset_version + ] = mock_rpc + + request = {} + client.create_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_create_dataset_version_rest_required_fields( + request_type=dataset_service.CreateDatasetVersionRequest, ): transport_class = transports.DatasetServiceRestTransport request_init = {} - request_init["name"] = "" + request_init["parent"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -9584,21 +12344,21 @@ def test_export_data_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).export_data._get_unset_required_fields(jsonified_request) + ).create_dataset_version._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["name"] = "name_value" + jsonified_request["parent"] = "parent_value" unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).export_data._get_unset_required_fields(jsonified_request) + ).create_dataset_version._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "name" in jsonified_request - assert jsonified_request["name"] == "name_value" + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), 
@@ -9632,32 +12392,32 @@ def test_export_data_rest_required_fields( response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.export_data(request) + response = client.create_dataset_version(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_export_data_rest_unset_required_fields(): +def test_create_dataset_version_rest_unset_required_fields(): transport = transports.DatasetServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.export_data._get_unset_required_fields({}) + unset_fields = transport.create_dataset_version._get_unset_required_fields({}) assert set(unset_fields) == ( set(()) & set( ( - "name", - "exportConfig", + "parent", + "datasetVersion", ) ) ) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_export_data_rest_interceptors(null_interceptor): +def test_create_dataset_version_rest_interceptors(null_interceptor): transport = transports.DatasetServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -9672,14 +12432,14 @@ def test_export_data_rest_interceptors(null_interceptor): ) as transcode, mock.patch.object( operation.Operation, "_set_result_from_operation" ), mock.patch.object( - transports.DatasetServiceRestInterceptor, "post_export_data" + transports.DatasetServiceRestInterceptor, "post_create_dataset_version" ) as post, mock.patch.object( - transports.DatasetServiceRestInterceptor, "pre_export_data" + transports.DatasetServiceRestInterceptor, "pre_create_dataset_version" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = dataset_service.ExportDataRequest.pb( - dataset_service.ExportDataRequest() + pb_message = dataset_service.CreateDatasetVersionRequest.pb( + dataset_service.CreateDatasetVersionRequest() ) transcode.return_value = { "method": "post", @@ -9695,7 +12455,7 @@ def 
test_export_data_rest_interceptors(null_interceptor): operations_pb2.Operation() ) - request = dataset_service.ExportDataRequest() + request = dataset_service.CreateDatasetVersionRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), @@ -9703,7 +12463,7 @@ def test_export_data_rest_interceptors(null_interceptor): pre.return_value = request, metadata post.return_value = operations_pb2.Operation() - client.export_data( + client.create_dataset_version( request, metadata=[ ("key", "val"), @@ -9715,8 +12475,8 @@ def test_export_data_rest_interceptors(null_interceptor): post.assert_called_once() -def test_export_data_rest_bad_request( - transport: str = "rest", request_type=dataset_service.ExportDataRequest +def test_create_dataset_version_rest_bad_request( + transport: str = "rest", request_type=dataset_service.CreateDatasetVersionRequest ): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -9724,7 +12484,7 @@ def test_export_data_rest_bad_request( ) # send a request that will satisfy transcoding - request_init = {"name": "projects/sample1/locations/sample2/datasets/sample3"} + request_init = {"parent": "projects/sample1/locations/sample2/datasets/sample3"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. 
@@ -9736,10 +12496,10 @@ def test_export_data_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.export_data(request) + client.create_dataset_version(request) -def test_export_data_rest_flattened(): +def test_create_dataset_version_rest_flattened(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -9751,16 +12511,14 @@ def test_export_data_rest_flattened(): return_value = operations_pb2.Operation(name="operations/spam") # get arguments that satisfy an http rule for this method - sample_request = {"name": "projects/sample1/locations/sample2/datasets/sample3"} + sample_request = { + "parent": "projects/sample1/locations/sample2/datasets/sample3" + } # get truthy value for each flattened field mock_args = dict( - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), + parent="parent_value", + dataset_version=gca_dataset_version.DatasetVersion(name="name_value"), ) mock_args.update(sample_request) @@ -9771,20 +12529,20 @@ def test_export_data_rest_flattened(): response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - client.export_data(**mock_args) + client.create_dataset_version(**mock_args) # Establish that the underlying call was made with the expected # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{name=projects/*/locations/*/datasets/*}:export" + "%s/v1beta1/{parent=projects/*/locations/*/datasets/*}/datasetVersions" % client.transport._host, args[1], ) -def test_export_data_rest_flattened_error(transport: str = "rest"): +def test_create_dataset_version_rest_flattened_error(transport: str = "rest"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -9793,18 +12551,14 @@ def test_export_data_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.export_data( - dataset_service.ExportDataRequest(), - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), + client.create_dataset_version( + dataset_service.CreateDatasetVersionRequest(), + parent="parent_value", + dataset_version=gca_dataset_version.DatasetVersion(name="name_value"), ) -def test_export_data_rest_error(): +def test_create_dataset_version_rest_error(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -9813,20 +12567,24 @@ def test_export_data_rest_error(): @pytest.mark.parametrize( "request_type", [ - dataset_service.CreateDatasetVersionRequest, + dataset_service.UpdateDatasetVersionRequest, dict, ], ) -def test_create_dataset_version_rest(request_type): +def test_update_dataset_version_rest(request_type): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2/datasets/sample3"} + request_init = { + "dataset_version": { + "name": "projects/sample1/locations/sample2/datasets/sample3/datasetVersions/sample4" + 
} + } request_init["dataset_version"] = { - "name": "name_value", + "name": "projects/sample1/locations/sample2/datasets/sample3/datasetVersions/sample4", "create_time": {"seconds": 751, "nanos": 543}, "update_time": {}, "etag": "etag_value", @@ -9846,7 +12604,7 @@ def test_create_dataset_version_rest(request_type): # See https://github.com/googleapis/gapic-generator-python/issues/1748 # Determine if the message type is proto-plus or protobuf - test_field = dataset_service.CreateDatasetVersionRequest.meta.fields[ + test_field = dataset_service.UpdateDatasetVersionRequest.meta.fields[ "dataset_version" ] @@ -9912,31 +12670,82 @@ def get_message_fields(field): del request_init["dataset_version"][field][subfield] request = request_type(**request_init) - # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), "request") as req: - # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = gca_dataset_version.DatasetVersion( + name="name_value", + etag="etag_value", + big_query_dataset_name="big_query_dataset_name_value", + display_name="display_name_value", + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = gca_dataset_version.DatasetVersion.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.update_dataset_version(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, gca_dataset_version.DatasetVersion) + assert response.name == "name_value" + assert response.etag == "etag_value" + assert response.big_query_dataset_name == "big_query_dataset_name_value" + assert response.display_name == "display_name_value" + + +def test_update_dataset_version_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 200 - json_return_value = json_format.MessageToJson(return_value) + # Ensure method has been cached + assert ( + client._transport.update_dataset_version + in client._transport._wrapped_methods + ) - response_value._content = json_return_value.encode("UTF-8") - req.return_value = response_value - response = client.create_dataset_version(request) + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_dataset_version + ] = mock_rpc - # Establish that the response is the type that we expect. - assert response.operation.name == "operations/spam" + request = {} + client.update_dataset_version(request) + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 -def test_create_dataset_version_rest_required_fields( - request_type=dataset_service.CreateDatasetVersionRequest, + client.update_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_update_dataset_version_rest_required_fields( + request_type=dataset_service.UpdateDatasetVersionRequest, ): transport_class = transports.DatasetServiceRestTransport request_init = {} - request_init["parent"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -9947,21 +12756,19 @@ def test_create_dataset_version_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).create_dataset_version._get_unset_required_fields(jsonified_request) + ).update_dataset_version._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["parent"] = "parent_value" - unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).create_dataset_version._get_unset_required_fields(jsonified_request) + ).update_dataset_version._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("update_mask",)) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "parent" in jsonified_request - assert jsonified_request["parent"] == "parent_value" client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -9970,7 +12777,7 @@ def test_create_dataset_version_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. 
- return_value = operations_pb2.Operation(name="operations/spam") + return_value = gca_dataset_version.DatasetVersion() # Mock the http request call within the method and fake a response. with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -9982,7 +12789,7 @@ def test_create_dataset_version_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "post", + "method": "patch", "query_params": pb_request, } transcode_result["body"] = pb_request @@ -9990,37 +12797,40 @@ def test_create_dataset_version_rest_required_fields( response_value = Response() response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = gca_dataset_version.DatasetVersion.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.create_dataset_version(request) + response = client.update_dataset_version(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_create_dataset_version_rest_unset_required_fields(): +def test_update_dataset_version_rest_unset_required_fields(): transport = transports.DatasetServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.create_dataset_version._get_unset_required_fields({}) + unset_fields = transport.update_dataset_version._get_unset_required_fields({}) assert set(unset_fields) == ( - set(()) + set(("updateMask",)) & set( ( - "parent", "datasetVersion", + "updateMask", ) ) ) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_create_dataset_version_rest_interceptors(null_interceptor): +def test_update_dataset_version_rest_interceptors(null_interceptor): transport = transports.DatasetServiceRestTransport( 
credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -10033,16 +12843,14 @@ def test_create_dataset_version_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - operation.Operation, "_set_result_from_operation" - ), mock.patch.object( - transports.DatasetServiceRestInterceptor, "post_create_dataset_version" + transports.DatasetServiceRestInterceptor, "post_update_dataset_version" ) as post, mock.patch.object( - transports.DatasetServiceRestInterceptor, "pre_create_dataset_version" + transports.DatasetServiceRestInterceptor, "pre_update_dataset_version" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = dataset_service.CreateDatasetVersionRequest.pb( - dataset_service.CreateDatasetVersionRequest() + pb_message = dataset_service.UpdateDatasetVersionRequest.pb( + dataset_service.UpdateDatasetVersionRequest() ) transcode.return_value = { "method": "post", @@ -10054,19 +12862,19 @@ def test_create_dataset_version_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = json_format.MessageToJson( - operations_pb2.Operation() + req.return_value._content = gca_dataset_version.DatasetVersion.to_json( + gca_dataset_version.DatasetVersion() ) - request = dataset_service.CreateDatasetVersionRequest() + request = dataset_service.UpdateDatasetVersionRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = operations_pb2.Operation() + post.return_value = gca_dataset_version.DatasetVersion() - client.create_dataset_version( + client.update_dataset_version( request, metadata=[ ("key", "val"), @@ -10078,8 +12886,8 @@ def test_create_dataset_version_rest_interceptors(null_interceptor): post.assert_called_once() -def test_create_dataset_version_rest_bad_request( - transport: str = 
"rest", request_type=dataset_service.CreateDatasetVersionRequest +def test_update_dataset_version_rest_bad_request( + transport: str = "rest", request_type=dataset_service.UpdateDatasetVersionRequest ): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -10087,7 +12895,11 @@ def test_create_dataset_version_rest_bad_request( ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2/datasets/sample3"} + request_init = { + "dataset_version": { + "name": "projects/sample1/locations/sample2/datasets/sample3/datasetVersions/sample4" + } + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -10099,10 +12911,10 @@ def test_create_dataset_version_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.create_dataset_version(request) + client.update_dataset_version(request) -def test_create_dataset_version_rest_flattened(): +def test_update_dataset_version_rest_flattened(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -10111,41 +12923,45 @@ def test_create_dataset_version_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
- return_value = operations_pb2.Operation(name="operations/spam") + return_value = gca_dataset_version.DatasetVersion() # get arguments that satisfy an http rule for this method sample_request = { - "parent": "projects/sample1/locations/sample2/datasets/sample3" + "dataset_version": { + "name": "projects/sample1/locations/sample2/datasets/sample3/datasetVersions/sample4" + } } # get truthy value for each flattened field mock_args = dict( - parent="parent_value", dataset_version=gca_dataset_version.DatasetVersion(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) mock_args.update(sample_request) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 + # Convert return value to protobuf type + return_value = gca_dataset_version.DatasetVersion.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - client.create_dataset_version(**mock_args) + client.update_dataset_version(**mock_args) # Establish that the underlying call was made with the expected # request object values. assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{parent=projects/*/locations/*/datasets/*}/datasetVersions" + "%s/v1beta1/{dataset_version.name=projects/*/locations/*/datasets/*/datasetVersions/*}" % client.transport._host, args[1], ) -def test_create_dataset_version_rest_flattened_error(transport: str = "rest"): +def test_update_dataset_version_rest_flattened_error(transport: str = "rest"): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -10154,14 +12970,14 @@ def test_create_dataset_version_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - client.create_dataset_version( - dataset_service.CreateDatasetVersionRequest(), - parent="parent_value", + client.update_dataset_version( + dataset_service.UpdateDatasetVersionRequest(), dataset_version=gca_dataset_version.DatasetVersion(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) -def test_create_dataset_version_rest_error(): +def test_update_dataset_version_rest_error(): client = DatasetServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -10204,6 +13020,51 @@ def test_delete_dataset_version_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_dataset_version_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_dataset_version + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_dataset_version + ] = mock_rpc + + request = {} + client.delete_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_dataset_version_rest_required_fields( request_type=dataset_service.DeleteDatasetVersionRequest, ): @@ -10480,6 +13341,46 @@ def test_get_dataset_version_rest(request_type): assert response.display_name == "display_name_value" +def test_get_dataset_version_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_dataset_version in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_dataset_version + ] = mock_rpc + + request = {} + client.get_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_dataset_version_rest_required_fields( request_type=dataset_service.GetDatasetVersionRequest, ): @@ -10753,6 +13654,47 @@ def test_list_dataset_versions_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_dataset_versions_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_dataset_versions + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_dataset_versions + ] = mock_rpc + + request = {} + client.list_dataset_versions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_dataset_versions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_dataset_versions_rest_required_fields( request_type=dataset_service.ListDatasetVersionsRequest, ): @@ -11099,6 +14041,51 @@ def test_restore_dataset_version_rest(request_type): assert response.operation.name == "operations/spam" +def test_restore_dataset_version_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.restore_dataset_version + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.restore_dataset_version + ] = mock_rpc + + request = {} + client.restore_dataset_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.restore_dataset_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_restore_dataset_version_rest_required_fields( request_type=dataset_service.RestoreDatasetVersionRequest, ): @@ -11367,6 +14354,42 @@ def test_list_data_items_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_data_items_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_data_items in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_data_items] = mock_rpc + + request = {} + client.list_data_items(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_data_items(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_data_items_rest_required_fields( request_type=dataset_service.ListDataItemsRequest, ): @@ -11716,6 +14739,44 @@ def test_search_data_items_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_search_data_items_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.search_data_items in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.search_data_items + ] = mock_rpc + + request = {} + client.search_data_items(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_data_items(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_search_data_items_rest_required_fields( request_type=dataset_service.SearchDataItemsRequest, ): @@ -12020,6 +15081,46 @@ def test_list_saved_queries_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_saved_queries_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_saved_queries in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_saved_queries + ] = mock_rpc + + request = {} + client.list_saved_queries(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_saved_queries(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_saved_queries_rest_required_fields( request_type=dataset_service.ListSavedQueriesRequest, ): @@ -12366,6 +15467,50 @@ def test_delete_saved_query_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_saved_query_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_saved_query in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_saved_query + ] = mock_rpc + + request = {} + client.delete_saved_query(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_saved_query(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_saved_query_rest_required_fields( request_type=dataset_service.DeleteSavedQueryRequest, ): @@ -12640,6 +15785,46 @@ def test_get_annotation_spec_rest(request_type): assert response.etag == "etag_value" +def test_get_annotation_spec_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_annotation_spec in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_annotation_spec + ] = mock_rpc + + request = {} + client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_annotation_spec(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_annotation_spec_rest_required_fields( request_type=dataset_service.GetAnnotationSpecRequest, ): @@ -12915,6 +16100,44 @@ def test_list_annotations_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_annotations_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DatasetServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_annotations in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_annotations + ] = mock_rpc + + request = {} + client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_annotations(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_annotations_rest_required_fields( request_type=dataset_service.ListAnnotationsRequest, ): @@ -13373,6 +16596,7 @@ def test_dataset_service_base_transport(): "import_data", "export_data", "create_dataset_version", + "update_dataset_version", "delete_dataset_version", "get_dataset_version", "list_dataset_versions", @@ -13697,6 +16921,9 @@ def test_dataset_service_client_transport_session_collision(transport_name): session1 = client1.transport.create_dataset_version._session session2 = client2.transport.create_dataset_version._session assert session1 != session2 + session1 = client1.transport.update_dataset_version._session + session2 = client2.transport.update_dataset_version._session + assert session1 != session2 session1 = client1.transport.delete_dataset_version._session session2 = client2.transport.delete_dataset_version._session assert session1 != session2 diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_deployment_resource_pool_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_deployment_resource_pool_service.py index 6661868aa9..6404084d31 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_deployment_resource_pool_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_deployment_resource_pool_service.py @@ -1287,6 +1287,9 @@ def test_create_deployment_resource_pool_empty_call(): with mock.patch.object( type(client.transport.create_deployment_resource_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_deployment_resource_pool() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1316,6 +1319,9 @@ def test_create_deployment_resource_pool_non_empty_request_with_auto_populated_f with mock.patch.object( type(client.transport.create_deployment_resource_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_deployment_resource_pool(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1327,6 +1333,50 @@ def test_create_deployment_resource_pool_non_empty_request_with_auto_populated_f ) +def test_create_deployment_resource_pool_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_deployment_resource_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_deployment_resource_pool + ] = mock_rpc + request = {} + client.create_deployment_resource_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_deployment_resource_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_deployment_resource_pool_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1353,6 +1403,56 @@ async def test_create_deployment_resource_pool_empty_call_async(): ) +@pytest.mark.asyncio +async def test_create_deployment_resource_pool_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_deployment_resource_pool + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_deployment_resource_pool + ] = mock_object + + request = {} + await client.create_deployment_resource_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_deployment_resource_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_deployment_resource_pool_async( transport: str = "grpc_asyncio", @@ -1625,6 +1725,9 @@ def test_get_deployment_resource_pool_empty_call(): with mock.patch.object( type(client.transport.get_deployment_resource_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_deployment_resource_pool() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1653,6 +1756,9 @@ def test_get_deployment_resource_pool_non_empty_request_with_auto_populated_fiel with mock.patch.object( type(client.transport.get_deployment_resource_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_deployment_resource_pool(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1663,6 +1769,46 @@ def test_get_deployment_resource_pool_non_empty_request_with_auto_populated_fiel ) +def test_get_deployment_resource_pool_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_deployment_resource_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_deployment_resource_pool + ] = mock_rpc + request = {} + client.get_deployment_resource_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_deployment_resource_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_deployment_resource_pool_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1691,6 +1837,52 @@ async def test_get_deployment_resource_pool_empty_call_async(): ) +@pytest.mark.asyncio +async def test_get_deployment_resource_pool_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_deployment_resource_pool + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_deployment_resource_pool + ] = mock_object + + request = {} + await client.get_deployment_resource_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_deployment_resource_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_deployment_resource_pool_async( transport: str = "grpc_asyncio", @@ -1936,6 +2128,9 @@ def test_list_deployment_resource_pools_empty_call(): with mock.patch.object( type(client.transport.list_deployment_resource_pools), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_deployment_resource_pools() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1965,6 +2160,9 @@ def test_list_deployment_resource_pools_non_empty_request_with_auto_populated_fi with mock.patch.object( type(client.transport.list_deployment_resource_pools), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_deployment_resource_pools(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1976,6 +2174,46 @@ def test_list_deployment_resource_pools_non_empty_request_with_auto_populated_fi ) +def test_list_deployment_resource_pools_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_deployment_resource_pools + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_deployment_resource_pools + ] = mock_rpc + request = {} + client.list_deployment_resource_pools(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_deployment_resource_pools(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_deployment_resource_pools_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2004,6 +2242,52 @@ async def test_list_deployment_resource_pools_empty_call_async(): ) +@pytest.mark.asyncio +async def test_list_deployment_resource_pools_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_deployment_resource_pools + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_deployment_resource_pools + ] = mock_object + + request = {} + await client.list_deployment_resource_pools(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_deployment_resource_pools(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_deployment_resource_pools_async( transport: str = "grpc_asyncio", @@ -2454,6 +2738,9 @@ def test_delete_deployment_resource_pool_empty_call(): with mock.patch.object( type(client.transport.delete_deployment_resource_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_deployment_resource_pool() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2482,6 +2769,9 @@ def test_delete_deployment_resource_pool_non_empty_request_with_auto_populated_f with mock.patch.object( type(client.transport.delete_deployment_resource_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_deployment_resource_pool(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2492,6 +2782,50 @@ def test_delete_deployment_resource_pool_non_empty_request_with_auto_populated_f ) +def test_delete_deployment_resource_pool_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_deployment_resource_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_deployment_resource_pool + ] = mock_rpc + request = {} + client.delete_deployment_resource_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_deployment_resource_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_deployment_resource_pool_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2518,6 +2852,56 @@ async def test_delete_deployment_resource_pool_empty_call_async(): ) +@pytest.mark.asyncio +async def test_delete_deployment_resource_pool_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_deployment_resource_pool + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_deployment_resource_pool + ] = mock_object + + request = {} + await client.delete_deployment_resource_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_deployment_resource_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_deployment_resource_pool_async( transport: str = "grpc_asyncio", @@ -2764,6 +3148,9 @@ def test_query_deployed_models_empty_call(): with mock.patch.object( type(client.transport.query_deployed_models), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.query_deployed_models() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2790,6 +3177,9 @@ def test_query_deployed_models_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.query_deployed_models), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.query_deployed_models(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2799,6 +3189,46 @@ def test_query_deployed_models_non_empty_request_with_auto_populated_field(): ) +def test_query_deployed_models_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.query_deployed_models + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.query_deployed_models + ] = mock_rpc + request = {} + client.query_deployed_models(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.query_deployed_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_query_deployed_models_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2826,6 +3256,52 @@ async def test_query_deployed_models_empty_call_async(): assert args[0] == deployment_resource_pool_service.QueryDeployedModelsRequest() +@pytest.mark.asyncio +async def test_query_deployed_models_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.query_deployed_models + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.query_deployed_models + ] = mock_object + + request = {} + await client.query_deployed_models(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.query_deployed_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_query_deployed_models_async( transport: str = "grpc_asyncio", @@ -3264,6 +3740,51 @@ def test_create_deployment_resource_pool_rest(request_type): assert response.operation.name == "operations/spam" +def test_create_deployment_resource_pool_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_deployment_resource_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_deployment_resource_pool + ] = mock_rpc + + request = {} + client.create_deployment_resource_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_deployment_resource_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_deployment_resource_pool_rest_required_fields( request_type=deployment_resource_pool_service.CreateDeploymentResourcePoolRequest, ): @@ -3562,6 +4083,47 @@ def test_get_deployment_resource_pool_rest(request_type): assert response.name == "name_value" +def test_get_deployment_resource_pool_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_deployment_resource_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_deployment_resource_pool + ] = mock_rpc + + request = {} + client.get_deployment_resource_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_deployment_resource_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_deployment_resource_pool_rest_required_fields( request_type=deployment_resource_pool_service.GetDeploymentResourcePoolRequest, ): @@ -3848,6 +4410,47 @@ def test_list_deployment_resource_pools_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_deployment_resource_pools_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_deployment_resource_pools + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_deployment_resource_pools + ] = mock_rpc + + request = {} + client.list_deployment_resource_pools(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_deployment_resource_pools(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_deployment_resource_pools_rest_required_fields( request_type=deployment_resource_pool_service.ListDeploymentResourcePoolsRequest, ): @@ -4213,6 +4816,51 @@ def test_delete_deployment_resource_pool_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_deployment_resource_pool_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_deployment_resource_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_deployment_resource_pool + ] = mock_rpc + + request = {} + client.delete_deployment_resource_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_deployment_resource_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_deployment_resource_pool_rest_required_fields( request_type=deployment_resource_pool_service.DeleteDeploymentResourcePoolRequest, ): @@ -4496,6 +5144,47 @@ def test_query_deployed_models_rest(request_type): assert response.total_endpoint_count == 2156 +def test_query_deployed_models_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DeploymentResourcePoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.query_deployed_models + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.query_deployed_models + ] = mock_rpc + + request = {} + client.query_deployed_models(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.query_deployed_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_query_deployed_models_rest_required_fields( request_type=deployment_resource_pool_service.QueryDeployedModelsRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py index c160aa6d93..422e63695f 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py @@ -67,6 +67,7 @@ from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import machine_resources from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import service_networking from google.cloud.location import locations_pb2 from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import options_pb2 # type: ignore @@ -1212,6 +1213,9 @@ def test_create_endpoint_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1236,6 +1240,9 @@ def test_create_endpoint_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_endpoint(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1245,6 +1252,45 @@ def test_create_endpoint_non_empty_request_with_auto_populated_field(): ) +def test_create_endpoint_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_endpoint in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_endpoint] = mock_rpc + request = {} + client.create_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_endpoint_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1266,6 +1312,56 @@ async def test_create_endpoint_empty_call_async(): assert args[0] == endpoint_service.CreateEndpointRequest() +@pytest.mark.asyncio +async def test_create_endpoint_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_endpoint + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_endpoint + ] = mock_object + + request = {} + await client.create_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_endpoint_async( transport: str = "grpc_asyncio", request_type=endpoint_service.CreateEndpointRequest @@ -1526,6 +1622,9 @@ def test_get_endpoint_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1549,6 +1648,9 @@ def test_get_endpoint_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_endpoint(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1557,6 +1659,41 @@ def test_get_endpoint_non_empty_request_with_auto_populated_field(): ) +def test_get_endpoint_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_endpoint in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_endpoint] = mock_rpc + request = {} + client.get_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_endpoint_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1586,6 +1723,52 @@ async def test_get_endpoint_empty_call_async(): assert args[0] == endpoint_service.GetEndpointRequest() +@pytest.mark.asyncio +async def test_get_endpoint_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_endpoint + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_endpoint + ] = mock_object + + request = {} + await client.get_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_endpoint_async( transport: str = "grpc_asyncio", request_type=endpoint_service.GetEndpointRequest @@ -1825,6 +2008,9 @@ def test_list_endpoints_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_endpoints() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1850,6 +2036,9 @@ def test_list_endpoints_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_endpoints(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1860,6 +2049,41 @@ def test_list_endpoints_non_empty_request_with_auto_populated_field(): ) +def test_list_endpoints_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_endpoints in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_endpoints] = mock_rpc + request = {} + client.list_endpoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_endpoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_endpoints_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1883,6 +2107,52 @@ async def test_list_endpoints_empty_call_async(): assert args[0] == endpoint_service.ListEndpointsRequest() +@pytest.mark.asyncio +async def test_list_endpoints_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_endpoints + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_endpoints + ] = mock_object + + request = {} + await client.list_endpoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_endpoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_endpoints_async( transport: str = "grpc_asyncio", request_type=endpoint_service.ListEndpointsRequest @@ -2316,6 +2586,9 @@ def test_update_endpoint_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2337,12 +2610,50 @@ def test_update_endpoint_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_endpoint(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == endpoint_service.UpdateEndpointRequest() +def test_update_endpoint_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_endpoint in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_endpoint] = mock_rpc + request = {} + client.update_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_endpoint_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2372,6 +2683,52 @@ async def test_update_endpoint_empty_call_async(): assert args[0] == endpoint_service.UpdateEndpointRequest() +@pytest.mark.asyncio +async def test_update_endpoint_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_endpoint + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_endpoint + ] = mock_object + + request = {} + await client.update_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_endpoint_async( transport: str = "grpc_asyncio", request_type=endpoint_service.UpdateEndpointRequest @@ -2622,6 +2979,9 @@ def test_delete_endpoint_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2645,6 +3005,9 @@ def test_delete_endpoint_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_endpoint(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2653,6 +3016,45 @@ def test_delete_endpoint_non_empty_request_with_auto_populated_field(): ) +def test_delete_endpoint_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_endpoint in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_endpoint] = mock_rpc + request = {} + client.delete_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_endpoint_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2674,6 +3076,56 @@ async def test_delete_endpoint_empty_call_async(): assert args[0] == endpoint_service.DeleteEndpointRequest() +@pytest.mark.asyncio +async def test_delete_endpoint_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_endpoint + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_endpoint + ] = mock_object + + request = {} + await client.delete_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_endpoint_async( transport: str = "grpc_asyncio", request_type=endpoint_service.DeleteEndpointRequest @@ -2896,6 +3348,9 @@ def test_deploy_model_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.deploy_model() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2919,6 +3374,9 @@ def test_deploy_model_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.deploy_model(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2927,6 +3385,45 @@ def test_deploy_model_non_empty_request_with_auto_populated_field(): ) +def test_deploy_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.deploy_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.deploy_model] = mock_rpc + request = {} + client.deploy_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.deploy_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_deploy_model_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2948,6 +3445,56 @@ async def test_deploy_model_empty_call_async(): assert args[0] == endpoint_service.DeployModelRequest() +@pytest.mark.asyncio +async def test_deploy_model_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.deploy_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.deploy_model + ] = mock_object + + request = {} + await client.deploy_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.deploy_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_deploy_model_async( transport: str = "grpc_asyncio", request_type=endpoint_service.DeployModelRequest @@ -3226,6 +3773,9 @@ def test_undeploy_model_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.undeploy_model() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3250,6 +3800,9 @@ def test_undeploy_model_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.undeploy_model(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3259,6 +3812,45 @@ def test_undeploy_model_non_empty_request_with_auto_populated_field(): ) +def test_undeploy_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.undeploy_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.undeploy_model] = mock_rpc + request = {} + client.undeploy_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.undeploy_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_undeploy_model_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3280,6 +3872,56 @@ async def test_undeploy_model_empty_call_async(): assert args[0] == endpoint_service.UndeployModelRequest() +@pytest.mark.asyncio +async def test_undeploy_model_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.undeploy_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.undeploy_model + ] = mock_object + + request = {} + await client.undeploy_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.undeploy_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_undeploy_model_async( transport: str = "grpc_asyncio", request_type=endpoint_service.UndeployModelRequest @@ -3526,6 +4168,9 @@ def test_mutate_deployed_model_empty_call(): with mock.patch.object( type(client.transport.mutate_deployed_model), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.mutate_deployed_model() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3551,6 +4196,9 @@ def test_mutate_deployed_model_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.mutate_deployed_model), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.mutate_deployed_model(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3559,6 +4207,50 @@ def test_mutate_deployed_model_non_empty_request_with_auto_populated_field(): ) +def test_mutate_deployed_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.mutate_deployed_model + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.mutate_deployed_model + ] = mock_rpc + request = {} + client.mutate_deployed_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.mutate_deployed_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_mutate_deployed_model_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3582,6 +4274,56 @@ async def test_mutate_deployed_model_empty_call_async(): assert args[0] == endpoint_service.MutateDeployedModelRequest() +@pytest.mark.asyncio +async def test_mutate_deployed_model_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = EndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.mutate_deployed_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.mutate_deployed_model + ] = mock_object + + request = {} + await client.mutate_deployed_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.mutate_deployed_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_mutate_deployed_model_async( transport: str = "grpc_asyncio", @@ -3940,6 +4682,13 @@ def test_create_endpoint_rest(request_type): "encryption_spec": {"kms_key_name": "kms_key_name_value"}, "network": "network_value", "enable_private_service_connect": True, + "private_service_connect_config": { + "enable_private_service_connect": True, + "project_allowlist": [ + "project_allowlist_value1", + "project_allowlist_value2", + ], + }, "model_deployment_monitoring_job": "model_deployment_monitoring_job_value", "predict_request_response_logging_config": { "enabled": True, @@ -4034,6 +4783,46 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_endpoint_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_endpoint in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[client._transport.create_endpoint] = mock_rpc + + request = {} + client.create_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_endpoint_rest_required_fields( request_type=endpoint_service.CreateEndpointRequest, ): @@ -4328,6 +5117,42 @@ def test_get_endpoint_rest(request_type): ) +def test_get_endpoint_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_endpoint in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_endpoint] = mock_rpc + + request = {} + client.get_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_endpoint_rest_required_fields( request_type=endpoint_service.GetEndpointRequest, ): @@ -4595,6 +5420,42 @@ def test_list_endpoints_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_endpoints_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_endpoints in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_endpoints] = mock_rpc + + request = {} + client.list_endpoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_endpoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_endpoints_rest_required_fields( request_type=endpoint_service.ListEndpointsRequest, ): @@ -5012,6 +5873,13 @@ def test_update_endpoint_rest(request_type): "encryption_spec": {"kms_key_name": "kms_key_name_value"}, "network": "network_value", "enable_private_service_connect": True, + "private_service_connect_config": { + "enable_private_service_connect": True, + "project_allowlist": [ + "project_allowlist_value1", + "project_allowlist_value2", + ], + }, "model_deployment_monitoring_job": "model_deployment_monitoring_job_value", "predict_request_response_logging_config": { "enabled": True, @@ -5126,6 +5994,42 @@ def get_message_fields(field): ) +def test_update_endpoint_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_endpoint in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_endpoint] = mock_rpc + + request = {} + client.update_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_endpoint_rest_required_fields( request_type=endpoint_service.UpdateEndpointRequest, ): @@ -5400,6 +6304,46 @@ def test_delete_endpoint_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_endpoint_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_endpoint in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_endpoint] = mock_rpc + + request = {} + client.delete_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_endpoint_rest_required_fields( request_type=endpoint_service.DeleteEndpointRequest, ): @@ -5661,6 +6605,46 @@ def test_deploy_model_rest(request_type): assert response.operation.name == "operations/spam" +def test_deploy_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.deploy_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.deploy_model] = mock_rpc + + request = {} + client.deploy_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.deploy_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_deploy_model_rest_required_fields( request_type=endpoint_service.DeployModelRequest, ): @@ -5947,6 +6931,46 @@ def test_undeploy_model_rest(request_type): assert response.operation.name == "operations/spam" +def test_undeploy_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.undeploy_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.undeploy_model] = mock_rpc + + request = {} + client.undeploy_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.undeploy_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_undeploy_model_rest_required_fields( request_type=endpoint_service.UndeployModelRequest, ): @@ -6225,6 +7249,51 @@ def test_mutate_deployed_model_rest(request_type): assert response.operation.name == "operations/spam" +def test_mutate_deployed_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.mutate_deployed_model + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.mutate_deployed_model + ] = mock_rpc + + request = {} + client.mutate_deployed_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.mutate_deployed_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_mutate_deployed_model_rest_required_fields( request_type=endpoint_service.MutateDeployedModelRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_evaluation_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_evaluation_service.py index 30f0d77cd6..3b92622e1b 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_evaluation_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_evaluation_service.py @@ -1212,6 +1212,9 @@ def test_evaluate_instances_empty_call(): with mock.patch.object( type(client.transport.evaluate_instances), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.evaluate_instances() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1237,6 +1240,9 @@ def test_evaluate_instances_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.evaluate_instances), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.evaluate_instances(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1245,6 +1251,45 @@ def test_evaluate_instances_non_empty_request_with_auto_populated_field(): ) +def test_evaluate_instances_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.evaluate_instances in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.evaluate_instances + ] = mock_rpc + request = {} + client.evaluate_instances(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.evaluate_instances(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_evaluate_instances_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1268,6 +1313,52 @@ async def test_evaluate_instances_empty_call_async(): assert args[0] == evaluation_service.EvaluateInstancesRequest() +@pytest.mark.asyncio +async def test_evaluate_instances_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.evaluate_instances + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.evaluate_instances + ] = mock_object + + request = {} + await client.evaluate_instances(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.evaluate_instances(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_evaluate_instances_async( transport: str = "grpc_asyncio", @@ -1409,6 +1500,46 @@ def test_evaluate_instances_rest(request_type): assert isinstance(response, evaluation_service.EvaluateInstancesResponse) +def test_evaluate_instances_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.evaluate_instances in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.evaluate_instances + ] = mock_rpc + + request = {} + client.evaluate_instances(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.evaluate_instances(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_evaluate_instances_rest_required_fields( request_type=evaluation_service.EvaluateInstancesRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_extension_execution_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_extension_execution_service.py index 68a77d8473..d6981240ca 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_extension_execution_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_extension_execution_service.py @@ -1268,6 +1268,9 @@ def test_execute_extension_empty_call(): with mock.patch.object( type(client.transport.execute_extension), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.execute_extension() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1294,6 +1297,9 @@ def test_execute_extension_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.execute_extension), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.execute_extension(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1303,6 +1309,43 @@ def test_execute_extension_non_empty_request_with_auto_populated_field(): ) +def test_execute_extension_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ExtensionExecutionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.execute_extension in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.execute_extension + ] = mock_rpc + request = {} + client.execute_extension(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.execute_extension(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_execute_extension_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1328,6 +1371,52 @@ async def test_execute_extension_empty_call_async(): assert args[0] == extension_execution_service.ExecuteExtensionRequest() +@pytest.mark.asyncio +async def test_execute_extension_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ExtensionExecutionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.execute_extension + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.execute_extension + ] = mock_object + + request = {} + await client.execute_extension(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.execute_extension(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_execute_extension_async( transport: str = "grpc_asyncio", @@ -1577,6 +1666,9 @@ def test_query_extension_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.query_extension), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.query_extension() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1600,6 +1692,9 @@ def test_query_extension_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.query_extension), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.query_extension(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1608,6 +1703,41 @@ def test_query_extension_non_empty_request_with_auto_populated_field(): ) +def test_query_extension_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ExtensionExecutionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.query_extension in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.query_extension] = mock_rpc + request = {} + client.query_extension(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.query_extension(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_query_extension_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1631,6 +1761,52 @@ async def test_query_extension_empty_call_async(): assert args[0] == extension_execution_service.QueryExtensionRequest() +@pytest.mark.asyncio +async def test_query_extension_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ExtensionExecutionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.query_extension + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.query_extension + ] = mock_object + + request = {} + await client.query_extension(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.query_extension(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_query_extension_async( transport: str = "grpc_asyncio", @@ -1866,6 +2042,44 @@ def test_execute_extension_rest(request_type): assert response.content == "content_value" +def test_execute_extension_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ExtensionExecutionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.execute_extension in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.execute_extension + ] = mock_rpc + + request = {} + client.execute_extension(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.execute_extension(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_execute_extension_rest_required_fields( request_type=extension_execution_service.ExecuteExtensionRequest, ): @@ -2159,6 +2373,42 @@ def test_query_extension_rest(request_type): assert response.failure_message == "failure_message_value" +def test_query_extension_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ExtensionExecutionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.query_extension in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.query_extension] = mock_rpc + + request = {} + client.query_extension(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.query_extension(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_query_extension_rest_required_fields( request_type=extension_execution_service.QueryExtensionRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_extension_registry_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_extension_registry_service.py index 469dc611fa..3338194d73 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_extension_registry_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_extension_registry_service.py @@ -1264,6 +1264,9 @@ def test_import_extension_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.import_extension), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.import_extension() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1287,6 +1290,9 @@ def test_import_extension_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.import_extension), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.import_extension(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1295,6 +1301,47 @@ def test_import_extension_non_empty_request_with_auto_populated_field(): ) +def test_import_extension_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ExtensionRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.import_extension in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.import_extension + ] = mock_rpc + request = {} + client.import_extension(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.import_extension(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_import_extension_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1316,6 +1363,56 @@ async def test_import_extension_empty_call_async(): assert args[0] == extension_registry_service.ImportExtensionRequest() +@pytest.mark.asyncio +async def test_import_extension_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ExtensionRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.import_extension + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.import_extension + ] = mock_object + + request = {} + await client.import_extension(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.import_extension(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_import_extension_async( transport: str = "grpc_asyncio", @@ -1558,6 +1655,9 @@ def test_get_extension_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_extension), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_extension() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1581,6 +1681,9 @@ def test_get_extension_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_extension), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_extension(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1589,6 +1692,41 @@ def test_get_extension_non_empty_request_with_auto_populated_field(): ) +def test_get_extension_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ExtensionRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_extension in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_extension] = mock_rpc + request = {} + client.get_extension(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_extension(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_extension_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1615,6 +1753,52 @@ async def test_get_extension_empty_call_async(): assert args[0] == extension_registry_service.GetExtensionRequest() +@pytest.mark.asyncio +async def test_get_extension_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ExtensionRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_extension + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_extension + ] = mock_object + + request = {} + await client.get_extension(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_extension(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_extension_async( transport: str = "grpc_asyncio", @@ -1846,6 +2030,9 @@ def test_list_extensions_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_extensions), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_extensions() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1872,6 +2059,9 @@ def test_list_extensions_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_extensions), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_extensions(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1883,6 +2073,41 @@ def test_list_extensions_non_empty_request_with_auto_populated_field(): ) +def test_list_extensions_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ExtensionRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_extensions in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_extensions] = mock_rpc + request = {} + client.list_extensions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_extensions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_extensions_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1906,6 +2131,52 @@ async def test_list_extensions_empty_call_async(): assert args[0] == extension_registry_service.ListExtensionsRequest() +@pytest.mark.asyncio +async def test_list_extensions_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ExtensionRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_extensions + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_extensions + ] = mock_object + + request = {} + await client.list_extensions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_extensions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_extensions_async( transport: str = "grpc_asyncio", @@ -2331,6 +2602,9 @@ def test_update_extension_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_extension), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_extension() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2352,12 +2626,52 @@ def test_update_extension_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_extension), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_extension(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == extension_registry_service.UpdateExtensionRequest() +def test_update_extension_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ExtensionRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_extension in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_extension + ] = mock_rpc + request = {} + client.update_extension(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_extension(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_extension_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2384,6 +2698,52 @@ async def test_update_extension_empty_call_async(): assert args[0] == extension_registry_service.UpdateExtensionRequest() +@pytest.mark.asyncio +async def test_update_extension_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ExtensionRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_extension + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_extension + ] = mock_object + + request = {} + await client.update_extension(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_extension(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_extension_async( transport: str = "grpc_asyncio", @@ -2626,6 +2986,9 @@ def test_delete_extension_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_extension), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_extension() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2649,6 +3012,9 @@ def test_delete_extension_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_extension), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_extension(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2657,6 +3023,47 @@ def test_delete_extension_non_empty_request_with_auto_populated_field(): ) +def test_delete_extension_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ExtensionRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_extension in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_extension + ] = mock_rpc + request = {} + client.delete_extension(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_extension(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_extension_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2678,6 +3085,56 @@ async def test_delete_extension_empty_call_async(): assert args[0] == extension_registry_service.DeleteExtensionRequest() +@pytest.mark.asyncio +async def test_delete_extension_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ExtensionRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_extension + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_extension + ] = mock_object + + request = {} + await client.delete_extension(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_extension(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_extension_async( transport: str = "grpc_asyncio", @@ -2955,7 +3412,8 @@ def test_import_extension_rest(request_type): "file_output_gcs_bucket": "file_output_gcs_bucket_value", }, "vertex_ai_search_runtime_config": { - "serving_config_name": "serving_config_name_value" + "serving_config_name": "serving_config_name_value", + "app_id": "app_id_value", }, "default_params": {}, }, @@ -3066,6 +3524,48 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_import_extension_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ExtensionRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.import_extension in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.import_extension + ] = mock_rpc + + request = {} + client.import_extension(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.import_extension(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_import_extension_rest_required_fields( request_type=extension_registry_service.ImportExtensionRequest, ): @@ -3348,6 +3848,42 @@ def test_get_extension_rest(request_type): assert response.etag == "etag_value" +def test_get_extension_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ExtensionRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_extension in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_extension] = mock_rpc + + request = {} + client.get_extension(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_extension(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_extension_rest_required_fields( request_type=extension_registry_service.GetExtensionRequest, ): @@ -3617,6 +4153,42 @@ def test_list_extensions_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_extensions_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ExtensionRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_extensions in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_extensions] = mock_rpc + + request = {} + client.list_extensions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_extensions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_extensions_rest_required_fields( request_type=extension_registry_service.ListExtensionsRequest, ): @@ -4027,7 +4599,8 @@ def test_update_extension_rest(request_type): "file_output_gcs_bucket": "file_output_gcs_bucket_value", }, "vertex_ai_search_runtime_config": { - "serving_config_name": "serving_config_name_value" + "serving_config_name": "serving_config_name_value", + "app_id": "app_id_value", }, "default_params": {}, }, @@ -4149,6 +4722,44 @@ def get_message_fields(field): assert response.etag == "etag_value" +def test_update_extension_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ExtensionRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_extension in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_extension + ] = mock_rpc + + request = {} + client.update_extension(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_extension(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_extension_rest_required_fields( request_type=extension_registry_service.UpdateExtensionRequest, ): @@ -4426,6 +5037,48 @@ def test_delete_extension_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_extension_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ExtensionRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_extension in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_extension + ] = mock_rpc + + request = {} + client.delete_extension(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_extension(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_extension_rest_required_fields( request_type=extension_registry_service.DeleteExtensionRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_feature_online_store_admin_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_feature_online_store_admin_service.py index 98ae39741f..608d96dfb3 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_feature_online_store_admin_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_feature_online_store_admin_service.py @@ -1294,6 +1294,9 @@ def test_create_feature_online_store_empty_call(): with mock.patch.object( type(client.transport.create_feature_online_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_feature_online_store() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1323,6 +1326,9 @@ def test_create_feature_online_store_non_empty_request_with_auto_populated_field with mock.patch.object( type(client.transport.create_feature_online_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_feature_online_store(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1334,6 +1340,50 @@ def test_create_feature_online_store_non_empty_request_with_auto_populated_field ) +def test_create_feature_online_store_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_feature_online_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_feature_online_store + ] = mock_rpc + request = {} + client.create_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_feature_online_store_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1360,6 +1410,56 @@ async def test_create_feature_online_store_empty_call_async(): ) +@pytest.mark.asyncio +async def test_create_feature_online_store_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_feature_online_store + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_feature_online_store + ] = mock_object + + request = {} + await client.create_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_feature_online_store_async( transport: str = "grpc_asyncio", @@ -1660,6 +1760,9 @@ def test_get_feature_online_store_empty_call(): with mock.patch.object( type(client.transport.get_feature_online_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_feature_online_store() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1687,6 +1790,9 @@ def test_get_feature_online_store_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_feature_online_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_feature_online_store(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1697,6 +1803,46 @@ def test_get_feature_online_store_non_empty_request_with_auto_populated_field(): ) +def test_get_feature_online_store_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_feature_online_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_feature_online_store + ] = mock_rpc + request = {} + client.get_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_online_store_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1726,6 +1872,52 @@ async def test_get_feature_online_store_empty_call_async(): ) +@pytest.mark.asyncio +async def test_get_feature_online_store_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_feature_online_store + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_feature_online_store + ] = mock_object + + request = {} + await client.get_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_online_store_async( transport: str = "grpc_asyncio", @@ -1975,6 +2167,9 @@ def test_list_feature_online_stores_empty_call(): with mock.patch.object( type(client.transport.list_feature_online_stores), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_feature_online_stores() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2006,6 +2201,9 @@ def test_list_feature_online_stores_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.list_feature_online_stores), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_feature_online_stores(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2019,6 +2217,46 @@ def test_list_feature_online_stores_non_empty_request_with_auto_populated_field( ) +def test_list_feature_online_stores_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_feature_online_stores + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_feature_online_stores + ] = mock_rpc + request = {} + client.list_feature_online_stores(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_feature_online_stores(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_feature_online_stores_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2047,6 +2285,52 @@ async def test_list_feature_online_stores_empty_call_async(): ) +@pytest.mark.asyncio +async def test_list_feature_online_stores_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_feature_online_stores + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_feature_online_stores + ] = mock_object + + request = {} + await client.list_feature_online_stores(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_feature_online_stores(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_feature_online_stores_async( transport: str = "grpc_asyncio", @@ -2495,6 +2779,9 @@ def test_update_feature_online_store_empty_call(): with mock.patch.object( type(client.transport.update_feature_online_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_feature_online_store() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2521,6 +2808,9 @@ def test_update_feature_online_store_non_empty_request_with_auto_populated_field with mock.patch.object( type(client.transport.update_feature_online_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_feature_online_store(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2530,6 +2820,50 @@ def test_update_feature_online_store_non_empty_request_with_auto_populated_field ) +def test_update_feature_online_store_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_feature_online_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_feature_online_store + ] = mock_rpc + request = {} + client.update_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_feature_online_store_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2556,6 +2890,56 @@ async def test_update_feature_online_store_empty_call_async(): ) +@pytest.mark.asyncio +async def test_update_feature_online_store_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_feature_online_store + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_feature_online_store + ] = mock_object + + request = {} + await client.update_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_feature_online_store_async( transport: str = "grpc_asyncio", @@ -2839,6 +3223,9 @@ def test_delete_feature_online_store_empty_call(): with mock.patch.object( type(client.transport.delete_feature_online_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_feature_online_store() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2867,6 +3254,9 @@ def test_delete_feature_online_store_non_empty_request_with_auto_populated_field with mock.patch.object( type(client.transport.delete_feature_online_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_feature_online_store(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2877,6 +3267,50 @@ def test_delete_feature_online_store_non_empty_request_with_auto_populated_field ) +def test_delete_feature_online_store_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_feature_online_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_feature_online_store + ] = mock_rpc + request = {} + client.delete_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_online_store_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2903,6 +3337,56 @@ async def test_delete_feature_online_store_empty_call_async(): ) +@pytest.mark.asyncio +async def test_delete_feature_online_store_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_feature_online_store + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_feature_online_store + ] = mock_object + + request = {} + await client.delete_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_online_store_async( transport: str = "grpc_asyncio", @@ -3150,6 +3634,9 @@ def test_create_feature_view_empty_call(): with mock.patch.object( type(client.transport.create_feature_view), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_feature_view() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3176,6 +3663,9 @@ def test_create_feature_view_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_feature_view), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_feature_view(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3185,6 +3675,49 @@ def test_create_feature_view_non_empty_request_with_auto_populated_field(): ) +def test_create_feature_view_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_feature_view in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_feature_view + ] = mock_rpc + request = {} + client.create_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_feature_view_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3208,6 +3741,56 @@ async def test_create_feature_view_empty_call_async(): assert args[0] == feature_online_store_admin_service.CreateFeatureViewRequest() +@pytest.mark.asyncio +async def test_create_feature_view_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_feature_view + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_feature_view + ] = mock_object + + request = {} + await client.create_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_feature_view_async( transport: str = "grpc_asyncio", @@ -3497,6 +4080,9 @@ def test_get_feature_view_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_feature_view), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_feature_view() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3520,6 +4106,9 @@ def test_get_feature_view_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_feature_view), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_feature_view(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3528,12 +4117,49 @@ def test_get_feature_view_non_empty_request_with_auto_populated_field(): ) -@pytest.mark.asyncio -async def test_get_feature_view_empty_call_async(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. 
- client = FeatureOnlineStoreAdminServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), +def test_get_feature_view_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_feature_view in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_feature_view + ] = mock_rpc + request = {} + client.get_feature_view(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_feature_view_empty_call_async(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. 
+ client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", ) @@ -3554,6 +4180,52 @@ async def test_get_feature_view_empty_call_async(): assert args[0] == feature_online_store_admin_service.GetFeatureViewRequest() +@pytest.mark.asyncio +async def test_get_feature_view_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_feature_view + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_feature_view + ] = mock_object + + request = {} + await client.get_feature_view(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.get_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_view_async( transport: str = "grpc_asyncio", @@ -3796,6 +4468,9 @@ def test_list_feature_views_empty_call(): with mock.patch.object( type(client.transport.list_feature_views), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_feature_views() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3824,6 +4499,9 @@ def test_list_feature_views_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_feature_views), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_feature_views(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3835,6 +4513,45 @@ def test_list_feature_views_non_empty_request_with_auto_populated_field(): ) +def test_list_feature_views_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_feature_views in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_feature_views + ] = mock_rpc + request = {} + client.list_feature_views(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_feature_views(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_feature_views_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3860,6 +4577,52 @@ async def test_list_feature_views_empty_call_async(): assert args[0] == feature_online_store_admin_service.ListFeatureViewsRequest() +@pytest.mark.asyncio +async def test_list_feature_views_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_feature_views + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_feature_views + ] = mock_object + + request = {} + await client.list_feature_views(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_feature_views(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_feature_views_async( transport: str = "grpc_asyncio", @@ -4304,6 +5067,9 @@ def test_update_feature_view_empty_call(): with mock.patch.object( type(client.transport.update_feature_view), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_feature_view() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4327,12 +5093,58 @@ def test_update_feature_view_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.update_feature_view), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_feature_view(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == feature_online_store_admin_service.UpdateFeatureViewRequest() +def test_update_feature_view_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_feature_view in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.update_feature_view + ] = mock_rpc + request = {} + client.update_feature_view(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_feature_view_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4356,6 +5168,56 @@ async def test_update_feature_view_empty_call_async(): assert args[0] == feature_online_store_admin_service.UpdateFeatureViewRequest() +@pytest.mark.asyncio +async def test_update_feature_view_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_feature_view + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_feature_view + ] = mock_object + + request = {} + await client.update_feature_view(request) + + # Establish that the underlying 
gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_feature_view_async( transport: str = "grpc_asyncio", @@ -4627,6 +5489,9 @@ def test_delete_feature_view_empty_call(): with mock.patch.object( type(client.transport.delete_feature_view), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_feature_view() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4652,6 +5517,9 @@ def test_delete_feature_view_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_feature_view), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_feature_view(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4660,6 +5528,49 @@ def test_delete_feature_view_non_empty_request_with_auto_populated_field(): ) +def test_delete_feature_view_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_feature_view in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_feature_view + ] = mock_rpc + request = {} + client.delete_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_view_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4683,6 +5594,56 @@ async def test_delete_feature_view_empty_call_async(): assert args[0] == feature_online_store_admin_service.DeleteFeatureViewRequest() +@pytest.mark.asyncio +async def test_delete_feature_view_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_feature_view + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_feature_view + ] = mock_object + + request = {} + await client.delete_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_view_async( transport: str = "grpc_asyncio", @@ -4925,6 +5886,9 @@ def test_sync_feature_view_empty_call(): with mock.patch.object( type(client.transport.sync_feature_view), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.sync_feature_view() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4950,6 +5914,9 @@ def test_sync_feature_view_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.sync_feature_view), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.sync_feature_view(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4958,6 +5925,43 @@ def test_sync_feature_view_non_empty_request_with_auto_populated_field(): ) +def test_sync_feature_view_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.sync_feature_view in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.sync_feature_view + ] = mock_rpc + request = {} + client.sync_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.sync_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_sync_feature_view_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4983,6 +5987,52 @@ async def test_sync_feature_view_empty_call_async(): assert args[0] == feature_online_store_admin_service.SyncFeatureViewRequest() +@pytest.mark.asyncio +async def test_sync_feature_view_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.sync_feature_view + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.sync_feature_view + ] = mock_object + + request = {} + await client.sync_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.sync_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_sync_feature_view_async( transport: str = "grpc_asyncio", @@ -5228,6 +6278,9 @@ def test_get_feature_view_sync_empty_call(): with mock.patch.object( type(client.transport.get_feature_view_sync), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_feature_view_sync() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5253,6 +6306,9 @@ def test_get_feature_view_sync_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_feature_view_sync), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_feature_view_sync(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5261,6 +6317,46 @@ def test_get_feature_view_sync_non_empty_request_with_auto_populated_field(): ) +def test_get_feature_view_sync_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_feature_view_sync + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.get_feature_view_sync + ] = mock_rpc + request = {} + client.get_feature_view_sync(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_feature_view_sync(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_view_sync_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5286,6 +6382,52 @@ async def test_get_feature_view_sync_empty_call_async(): assert args[0] == feature_online_store_admin_service.GetFeatureViewSyncRequest() +@pytest.mark.asyncio +async def test_get_feature_view_sync_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_feature_view_sync + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_feature_view_sync + ] = mock_object + + request = {} + await client.get_feature_view_sync(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_feature_view_sync(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_view_sync_async( transport: str = "grpc_asyncio", @@ -5531,6 +6673,9 @@ def test_list_feature_view_syncs_empty_call(): with mock.patch.object( type(client.transport.list_feature_view_syncs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_feature_view_syncs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5561,6 +6706,9 @@ def test_list_feature_view_syncs_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_feature_view_syncs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_feature_view_syncs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5574,6 +6722,46 @@ def test_list_feature_view_syncs_non_empty_request_with_auto_populated_field(): ) +def test_list_feature_view_syncs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_feature_view_syncs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a 
string. + ) + client._transport._wrapped_methods[ + client._transport.list_feature_view_syncs + ] = mock_rpc + request = {} + client.list_feature_view_syncs(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_feature_view_syncs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_feature_view_syncs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5593,13 +6781,59 @@ async def test_list_feature_view_syncs_empty_call_async(): next_page_token="next_page_token_value", ) ) - response = await client.list_feature_view_syncs() - call.assert_called() - _, args, _ = call.mock_calls[0] + response = await client.list_feature_view_syncs() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert ( + args[0] == feature_online_store_admin_service.ListFeatureViewSyncsRequest() + ) + + +@pytest.mark.asyncio +async def test_list_feature_view_syncs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached assert ( - args[0] == feature_online_store_admin_service.ListFeatureViewSyncsRequest() + client._client._transport.list_feature_view_syncs + in client._client._transport._wrapped_methods ) + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return 
iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_feature_view_syncs + ] = mock_object + + request = {} + await client.list_feature_view_syncs(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.list_feature_view_syncs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + @pytest.mark.asyncio async def test_list_feature_view_syncs_async( @@ -6134,6 +7368,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_feature_online_store_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_feature_online_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_feature_online_store + ] = mock_rpc + + request = {} + client.create_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_feature_online_store_rest_required_fields( request_type=feature_online_store_admin_service.CreateFeatureOnlineStoreRequest, ): @@ -6452,6 +7731,47 @@ def test_get_feature_online_store_rest(request_type): assert response.state == feature_online_store.FeatureOnlineStore.State.STABLE +def test_get_feature_online_store_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_feature_online_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_feature_online_store + ] = mock_rpc + + request = {} + client.get_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_feature_online_store_rest_required_fields( request_type=feature_online_store_admin_service.GetFeatureOnlineStoreRequest, ): @@ -6732,6 +8052,47 @@ def test_list_feature_online_stores_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_feature_online_stores_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_feature_online_stores + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_feature_online_stores + ] = mock_rpc + + request = {} + client.list_feature_online_stores(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_feature_online_stores(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_feature_online_stores_rest_required_fields( request_type=feature_online_store_admin_service.ListFeatureOnlineStoresRequest, ): @@ -7199,6 +8560,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_feature_online_store_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_feature_online_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_feature_online_store + ] = mock_rpc + + request = {} + client.update_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_feature_online_store_rest_required_fields( request_type=feature_online_store_admin_service.UpdateFeatureOnlineStoreRequest, ): @@ -7485,6 +8891,51 @@ def test_delete_feature_online_store_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_feature_online_store_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_feature_online_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_feature_online_store + ] = mock_rpc + + request = {} + client.delete_feature_online_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature_online_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_feature_online_store_rest_required_fields( request_type=feature_online_store_admin_service.DeleteFeatureOnlineStoreRequest, ): @@ -7874,6 +9325,50 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_feature_view_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_feature_view in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_feature_view + ] = mock_rpc + + request = {} + client.create_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_feature_view_rest_required_fields( request_type=feature_online_store_admin_service.CreateFeatureViewRequest, ): @@ -8202,6 +9697,44 @@ def test_get_feature_view_rest(request_type): assert response.service_account_email == "service_account_email_value" +def test_get_feature_view_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_feature_view in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_feature_view + ] = mock_rpc + + request = {} + client.get_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_feature_view_rest_required_fields( request_type=feature_online_store_admin_service.GetFeatureViewRequest, ): @@ -8479,6 +10012,46 @@ def test_list_feature_views_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_feature_views_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_feature_views in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_feature_views + ] = mock_rpc + + request = {} + client.list_feature_views(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_feature_views(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_feature_views_rest_required_fields( request_type=feature_online_store_admin_service.ListFeatureViewsRequest, ): @@ -8956,6 +10529,50 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_feature_view_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_feature_view in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_feature_view + ] = mock_rpc + + request = {} + client.update_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_feature_view_rest_required_fields( request_type=feature_online_store_admin_service.UpdateFeatureViewRequest, ): @@ -9236,6 +10853,50 @@ def test_delete_feature_view_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_feature_view_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_feature_view in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_feature_view + ] = mock_rpc + + request = {} + client.delete_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_feature_view_rest_required_fields( request_type=feature_online_store_admin_service.DeleteFeatureViewRequest, ): @@ -9513,6 +11174,44 @@ def test_sync_feature_view_rest(request_type): assert response.feature_view_sync == "feature_view_sync_value" +def test_sync_feature_view_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.sync_feature_view in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.sync_feature_view + ] = mock_rpc + + request = {} + client.sync_feature_view(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.sync_feature_view(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_sync_feature_view_rest_required_fields( request_type=feature_online_store_admin_service.SyncFeatureViewRequest, ): @@ -9798,6 +11497,47 @@ def test_get_feature_view_sync_rest(request_type): assert response.name == "name_value" +def test_get_feature_view_sync_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_feature_view_sync + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_feature_view_sync + ] = mock_rpc + + request = {} + client.get_feature_view_sync(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_feature_view_sync(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_feature_view_sync_rest_required_fields( request_type=feature_online_store_admin_service.GetFeatureViewSyncRequest, ): @@ -10078,6 +11818,47 @@ def test_list_feature_view_syncs_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_feature_view_syncs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreAdminServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_feature_view_syncs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_feature_view_syncs + ] = mock_rpc + + request = {} + client.list_feature_view_syncs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_feature_view_syncs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_feature_view_syncs_rest_required_fields( request_type=feature_online_store_admin_service.ListFeatureViewSyncsRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_feature_online_store_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_feature_online_store_service.py index 1b523a1c41..846724ce11 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_feature_online_store_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_feature_online_store_service.py @@ -1262,6 +1262,9 @@ def test_fetch_feature_values_empty_call(): with mock.patch.object( type(client.transport.fetch_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.fetch_feature_values() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1288,6 +1291,9 @@ def test_fetch_feature_values_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.fetch_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.fetch_feature_values(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1297,6 +1303,45 @@ def test_fetch_feature_values_non_empty_request_with_auto_populated_field(): ) +def test_fetch_feature_values_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.fetch_feature_values in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.fetch_feature_values + ] = mock_rpc + request = {} + client.fetch_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.fetch_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_fetch_feature_values_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1320,6 +1365,52 @@ async def test_fetch_feature_values_empty_call_async(): assert args[0] == feature_online_store_service.FetchFeatureValuesRequest() +@pytest.mark.asyncio +async def test_fetch_feature_values_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.fetch_feature_values + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.fetch_feature_values + ] = mock_object + + request = {} + await client.fetch_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.fetch_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_fetch_feature_values_async( transport: str = "grpc_asyncio", @@ -1560,6 +1651,92 @@ def test_streaming_fetch_feature_values(request_type, transport: str = "grpc"): ) +def test_streaming_fetch_feature_values_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.streaming_fetch_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.streaming_fetch_feature_values + ] = mock_rpc + request = [{}] + client.streaming_fetch_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.streaming_fetch_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_streaming_fetch_feature_values_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.streaming_fetch_feature_values + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.streaming_fetch_feature_values + ] = mock_object + + request = [{}] + await client.streaming_fetch_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.streaming_fetch_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_streaming_fetch_feature_values_async( transport: str = "grpc_asyncio", @@ -1654,6 +1831,9 @@ def test_search_nearest_entities_empty_call(): with mock.patch.object( type(client.transport.search_nearest_entities), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.search_nearest_entities() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1679,6 +1859,9 @@ def test_search_nearest_entities_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.search_nearest_entities), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.search_nearest_entities(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1687,6 +1870,46 @@ def test_search_nearest_entities_non_empty_request_with_auto_populated_field(): ) +def test_search_nearest_entities_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.search_nearest_entities + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.search_nearest_entities + ] = mock_rpc + request = {} + client.search_nearest_entities(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_nearest_entities(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_search_nearest_entities_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1710,6 +1933,52 @@ async def test_search_nearest_entities_empty_call_async(): assert args[0] == feature_online_store_service.SearchNearestEntitiesRequest() +@pytest.mark.asyncio +async def test_search_nearest_entities_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.search_nearest_entities + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.search_nearest_entities + ] = mock_object + + request = {} + await client.search_nearest_entities(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.search_nearest_entities(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_search_nearest_entities_async( transport: str = "grpc_asyncio", @@ -1857,6 +2126,46 @@ def test_fetch_feature_values_rest(request_type): assert isinstance(response, feature_online_store_service.FetchFeatureValuesResponse) +def test_fetch_feature_values_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.fetch_feature_values in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.fetch_feature_values + ] = mock_rpc + + request = {} + client.fetch_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.fetch_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_fetch_feature_values_rest_required_fields( request_type=feature_online_store_service.FetchFeatureValuesRequest, ): @@ -2152,6 +2461,47 @@ def test_search_nearest_entities_rest(request_type): ) +def test_search_nearest_entities_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureOnlineStoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.search_nearest_entities + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.search_nearest_entities + ] = mock_rpc + + request = {} + client.search_nearest_entities(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_nearest_entities(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_search_nearest_entities_rest_required_fields( request_type=feature_online_store_service.SearchNearestEntitiesRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_feature_registry_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_feature_registry_service.py index 08b94ba5ae..80e6e4116e 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_feature_registry_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_feature_registry_service.py @@ -1264,6 +1264,9 @@ def test_create_feature_group_empty_call(): with mock.patch.object( type(client.transport.create_feature_group), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_feature_group() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1290,6 +1293,9 @@ def test_create_feature_group_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_feature_group), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_feature_group(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1299,6 +1305,49 @@ def test_create_feature_group_non_empty_request_with_auto_populated_field(): ) +def test_create_feature_group_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_feature_group in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_feature_group + ] = mock_rpc + request = {} + client.create_feature_group(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_feature_group_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1322,6 +1371,56 @@ async def test_create_feature_group_empty_call_async(): assert args[0] == feature_registry_service.CreateFeatureGroupRequest() +@pytest.mark.asyncio +async def test_create_feature_group_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_feature_group + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_feature_group + ] = mock_object + + request = {} + await client.create_feature_group(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_feature_group_async( transport: str = "grpc_asyncio", @@ -1610,6 +1709,9 @@ def test_get_feature_group_empty_call(): with mock.patch.object( type(client.transport.get_feature_group), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_feature_group() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1635,6 +1737,9 @@ def test_get_feature_group_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_feature_group), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_feature_group(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1643,6 +1748,43 @@ def test_get_feature_group_non_empty_request_with_auto_populated_field(): ) +def test_get_feature_group_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_feature_group in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_feature_group + ] = mock_rpc + request = {} + client.get_feature_group(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_group_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1670,6 +1812,52 @@ async def test_get_feature_group_empty_call_async(): assert args[0] == feature_registry_service.GetFeatureGroupRequest() +@pytest.mark.asyncio +async def test_get_feature_group_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_feature_group + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_feature_group + ] = mock_object + + request = {} + await client.get_feature_group(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_group_async( transport: str = "grpc_asyncio", @@ -1917,6 +2105,9 @@ def test_list_feature_groups_empty_call(): with mock.patch.object( type(client.transport.list_feature_groups), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_feature_groups() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1945,6 +2136,9 @@ def test_list_feature_groups_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_feature_groups), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_feature_groups(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1956,6 +2150,45 @@ def test_list_feature_groups_non_empty_request_with_auto_populated_field(): ) +def test_list_feature_groups_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_feature_groups in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_feature_groups + ] = mock_rpc + request = {} + client.list_feature_groups(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_feature_groups(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_feature_groups_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1981,6 +2214,52 @@ async def test_list_feature_groups_empty_call_async(): assert args[0] == feature_registry_service.ListFeatureGroupsRequest() +@pytest.mark.asyncio +async def test_list_feature_groups_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_feature_groups + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_feature_groups + ] = mock_object + + request = {} + await client.list_feature_groups(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_feature_groups(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_feature_groups_async( transport: str = "grpc_asyncio", @@ -2419,6 +2698,9 @@ def test_update_feature_group_empty_call(): with mock.patch.object( type(client.transport.update_feature_group), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_feature_group() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2442,12 +2724,58 @@ def test_update_feature_group_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.update_feature_group), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_feature_group(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == feature_registry_service.UpdateFeatureGroupRequest() +def test_update_feature_group_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_feature_group in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.update_feature_group + ] = mock_rpc + request = {} + client.update_feature_group(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_feature_group_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2471,6 +2799,56 @@ async def test_update_feature_group_empty_call_async(): assert args[0] == feature_registry_service.UpdateFeatureGroupRequest() +@pytest.mark.asyncio +async def test_update_feature_group_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_feature_group + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_feature_group + ] = mock_object + + request = {} + await client.update_feature_group(request) + + # Establish that the underlying gRPC stub 
method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_feature_group_async( transport: str = "grpc_asyncio", @@ -2742,6 +3120,9 @@ def test_delete_feature_group_empty_call(): with mock.patch.object( type(client.transport.delete_feature_group), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_feature_group() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2767,6 +3148,9 @@ def test_delete_feature_group_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_feature_group), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_feature_group(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2775,6 +3159,49 @@ def test_delete_feature_group_non_empty_request_with_auto_populated_field(): ) +def test_delete_feature_group_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_feature_group in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_feature_group + ] = mock_rpc + request = {} + client.delete_feature_group(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_group_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2798,6 +3225,56 @@ async def test_delete_feature_group_empty_call_async(): assert args[0] == feature_registry_service.DeleteFeatureGroupRequest() +@pytest.mark.asyncio +async def test_delete_feature_group_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_feature_group + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_feature_group + ] = mock_object + + request = {} + await client.delete_feature_group(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_group_async( transport: str = "grpc_asyncio", @@ -3041,6 +3518,9 @@ def test_create_feature_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_feature() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3065,6 +3545,9 @@ def test_create_feature_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_feature(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3074,6 +3557,45 @@ def test_create_feature_non_empty_request_with_auto_populated_field(): ) +def test_create_feature_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_feature] = mock_rpc + request = {} + client.create_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_feature_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3095,6 +3617,56 @@ async def test_create_feature_empty_call_async(): assert args[0] == featurestore_service.CreateFeatureRequest() +@pytest.mark.asyncio +async def test_create_feature_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_feature + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_feature + ] = mock_object + + request = {} + await client.create_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_feature_async( transport: str = "grpc_asyncio", @@ -3353,6 +3925,9 @@ def test_get_feature_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_feature() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3376,6 +3951,9 @@ def test_get_feature_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_feature(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3384,6 +3962,41 @@ def test_get_feature_non_empty_request_with_auto_populated_field(): ) +def test_get_feature_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_feature] = mock_rpc + request = {} + client.get_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3413,6 +4026,52 @@ async def test_get_feature_empty_call_async(): assert args[0] == featurestore_service.GetFeatureRequest() +@pytest.mark.asyncio +async def test_get_feature_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_feature + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_feature + ] = mock_object + + request = {} + await client.get_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_async( transport: str = "grpc_asyncio", request_type=featurestore_service.GetFeatureRequest @@ -3649,6 +4308,9 @@ def test_list_features_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_features), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_features() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3675,6 +4337,9 @@ def test_list_features_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_features), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_features(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3686,6 +4351,41 @@ def test_list_features_non_empty_request_with_auto_populated_field(): ) +def test_list_features_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_features in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_features] = mock_rpc + request = {} + client.list_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_features_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3709,6 +4409,52 @@ async def test_list_features_empty_call_async(): assert args[0] == featurestore_service.ListFeaturesRequest() +@pytest.mark.asyncio +async def test_list_features_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_features + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_features + ] = mock_object + + request = {} + await client.list_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_features_async( transport: str = "grpc_asyncio", @@ -4125,6 +4871,9 @@ def test_update_feature_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_feature() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4146,12 +4895,54 @@ def test_update_feature_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_feature(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.UpdateFeatureRequest() +def test_update_feature_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[client._transport.update_feature] = mock_rpc + request = {} + client.update_feature(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_feature_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4173,6 +4964,56 @@ async def test_update_feature_empty_call_async(): assert args[0] == featurestore_service.UpdateFeatureRequest() +@pytest.mark.asyncio +async def test_update_feature_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_feature + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_feature + ] = mock_object + + request = {} + await client.update_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_feature_async( transport: str = "grpc_asyncio", @@ -4406,6 +5247,9 @@ def test_delete_feature_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_feature() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4429,6 +5273,9 @@ def test_delete_feature_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_feature(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4437,6 +5284,45 @@ def test_delete_feature_non_empty_request_with_auto_populated_field(): ) +def test_delete_feature_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_feature] = mock_rpc + request = {} + client.delete_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4458,6 +5344,56 @@ async def test_delete_feature_empty_call_async(): assert args[0] == featurestore_service.DeleteFeatureRequest() +@pytest.mark.asyncio +async def test_delete_feature_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_feature + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_feature + ] = mock_object + + request = {} + await client.delete_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_async( transport: str = "grpc_asyncio", @@ -4757,6 +5693,50 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_feature_group_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_feature_group in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_feature_group + ] = mock_rpc + + request = {} + client.create_feature_group(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_feature_group_rest_required_fields( request_type=feature_registry_service.CreateFeatureGroupRequest, ): @@ -5062,6 +6042,44 @@ def test_get_feature_group_rest(request_type): assert response.description == "description_value" +def test_get_feature_group_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_feature_group in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_feature_group + ] = mock_rpc + + request = {} + client.get_feature_group(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_feature_group_rest_required_fields( request_type=feature_registry_service.GetFeatureGroupRequest, ): @@ -5334,6 +6352,46 @@ def test_list_feature_groups_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_feature_groups_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_feature_groups in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_feature_groups + ] = mock_rpc + + request = {} + client.list_feature_groups(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_feature_groups(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_feature_groups_rest_required_fields( request_type=feature_registry_service.ListFeatureGroupsRequest, ): @@ -5768,6 +6826,50 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_feature_group_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_feature_group in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_feature_group + ] = mock_rpc + + request = {} + client.update_feature_group(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_feature_group_rest_required_fields( request_type=feature_registry_service.UpdateFeatureGroupRequest, ): @@ -6044,6 +7146,50 @@ def test_delete_feature_group_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_feature_group_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_feature_group in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_feature_group + ] = mock_rpc + + request = {} + client.delete_feature_group(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature_group(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_feature_group_rest_required_fields( request_type=feature_registry_service.DeleteFeatureGroupRequest, ): @@ -6414,6 +7560,46 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_feature_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_feature] = mock_rpc + + request = {} + client.create_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_feature_rest_required_fields( request_type=featurestore_service.CreateFeatureRequest, ): @@ -6724,6 +7910,42 @@ def test_get_feature_rest(request_type): assert response.point_of_contact == "point_of_contact_value" +def test_get_feature_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_feature] = mock_rpc + + request = {} + client.get_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_feature_rest_required_fields( request_type=featurestore_service.GetFeatureRequest, ): @@ -6995,6 +8217,42 @@ def test_list_features_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_features_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_features in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_features] = mock_rpc + + request = {} + client.list_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_features_rest_required_fields( request_type=featurestore_service.ListFeaturesRequest, ): @@ -7449,6 +8707,46 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_feature_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_feature] = mock_rpc + + request = {} + client.update_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_feature_rest_required_fields( request_type=featurestore_service.UpdateFeatureRequest, ): @@ -7718,6 +9016,46 @@ def test_delete_feature_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_feature_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeatureRegistryServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_feature] = mock_rpc + + request = {} + client.delete_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_feature_rest_required_fields( request_type=featurestore_service.DeleteFeatureRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py index 7a24764a45..8b941bb0e4 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py @@ -1289,6 +1289,9 @@ def test_read_feature_values_empty_call(): with mock.patch.object( type(client.transport.read_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.read_feature_values() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1315,6 +1318,9 @@ def test_read_feature_values_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.read_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.read_feature_values(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1324,6 +1330,45 @@ def test_read_feature_values_non_empty_request_with_auto_populated_field(): ) +def test_read_feature_values_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreOnlineServingServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_feature_values in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_feature_values + ] = mock_rpc + request = {} + client.read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_read_feature_values_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1347,6 +1392,52 @@ async def test_read_feature_values_empty_call_async(): assert args[0] == featurestore_online_service.ReadFeatureValuesRequest() +@pytest.mark.asyncio +async def test_read_feature_values_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreOnlineServingServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.read_feature_values + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.read_feature_values + ] = mock_object + + request = {} + await client.read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.read_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_read_feature_values_async( transport: str = "grpc_asyncio", @@ -1589,6 +1680,9 @@ def test_streaming_read_feature_values_empty_call(): with mock.patch.object( type(client.transport.streaming_read_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.streaming_read_feature_values() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1616,6 +1710,9 @@ def test_streaming_read_feature_values_non_empty_request_with_auto_populated_fie with mock.patch.object( type(client.transport.streaming_read_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.streaming_read_feature_values(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1624,6 +1721,46 @@ def test_streaming_read_feature_values_non_empty_request_with_auto_populated_fie ) +def test_streaming_read_feature_values_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreOnlineServingServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.streaming_read_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.streaming_read_feature_values + ] = mock_rpc + request = {} + client.streaming_read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.streaming_read_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_streaming_read_feature_values_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1650,6 +1787,52 @@ async def test_streaming_read_feature_values_empty_call_async(): ) +@pytest.mark.asyncio +async def test_streaming_read_feature_values_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreOnlineServingServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.streaming_read_feature_values + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.streaming_read_feature_values + ] = mock_object + + request = {} + await client.streaming_read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.streaming_read_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_streaming_read_feature_values_async( transport: str = "grpc_asyncio", @@ -1894,6 +2077,9 @@ def test_write_feature_values_empty_call(): with mock.patch.object( type(client.transport.write_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.write_feature_values() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1919,6 +2105,9 @@ def test_write_feature_values_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.write_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.write_feature_values(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1927,6 +2116,45 @@ def test_write_feature_values_non_empty_request_with_auto_populated_field(): ) +def test_write_feature_values_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreOnlineServingServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.write_feature_values in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.write_feature_values + ] = mock_rpc + request = {} + client.write_feature_values(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.write_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_write_feature_values_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1950,6 +2178,52 @@ async def test_write_feature_values_empty_call_async(): assert args[0] == featurestore_online_service.WriteFeatureValuesRequest() +@pytest.mark.asyncio +async def test_write_feature_values_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreOnlineServingServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.write_feature_values + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.write_feature_values + ] = mock_object + + request = {} + await client.write_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.write_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_write_feature_values_async( transport: str = "grpc_asyncio", @@ -2215,6 +2489,46 @@ def test_read_feature_values_rest(request_type): assert isinstance(response, featurestore_online_service.ReadFeatureValuesResponse) +def test_read_feature_values_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreOnlineServingServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_feature_values in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_feature_values + ] = mock_rpc + + request = {} + client.read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_read_feature_values_rest_required_fields( request_type=featurestore_online_service.ReadFeatureValuesRequest, ): @@ -2517,6 +2831,47 @@ def test_streaming_read_feature_values_rest(request_type): assert isinstance(response, featurestore_online_service.ReadFeatureValuesResponse) +def test_streaming_read_feature_values_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreOnlineServingServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.streaming_read_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.streaming_read_feature_values + ] = mock_rpc + + request = {} + client.streaming_read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.streaming_read_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_streaming_read_feature_values_rest_required_fields( request_type=featurestore_online_service.StreamingReadFeatureValuesRequest, ): @@ -2821,6 +3176,46 @@ def test_write_feature_values_rest(request_type): assert isinstance(response, featurestore_online_service.WriteFeatureValuesResponse) +def test_write_feature_values_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreOnlineServingServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.write_feature_values in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.write_feature_values + ] = mock_rpc + + request = {} + client.write_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.write_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_write_feature_values_rest_required_fields( request_type=featurestore_online_service.WriteFeatureValuesRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py index adddbea265..316c313671 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py @@ -1264,6 +1264,9 @@ def test_create_featurestore_empty_call(): with mock.patch.object( type(client.transport.create_featurestore), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_featurestore() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1290,6 +1293,9 @@ def test_create_featurestore_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_featurestore), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_featurestore(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1299,6 +1305,49 @@ def test_create_featurestore_non_empty_request_with_auto_populated_field(): ) +def test_create_featurestore_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_featurestore in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_featurestore + ] = mock_rpc + request = {} + client.create_featurestore(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_featurestore_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1322,6 +1371,56 @@ async def test_create_featurestore_empty_call_async(): assert args[0] == featurestore_service.CreateFeaturestoreRequest() +@pytest.mark.asyncio +async def test_create_featurestore_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_featurestore + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_featurestore + ] = mock_object + + request = {} + await client.create_featurestore(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_featurestore_async( transport: str = "grpc_asyncio", @@ -1584,6 +1683,9 @@ def test_get_featurestore_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_featurestore), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_featurestore() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1607,6 +1709,9 @@ def test_get_featurestore_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_featurestore), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_featurestore(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1615,6 +1720,43 @@ def test_get_featurestore_non_empty_request_with_auto_populated_field(): ) +def test_get_featurestore_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_featurestore in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_featurestore + ] = mock_rpc + request = {} + client.get_featurestore(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_featurestore_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1641,6 +1783,52 @@ async def test_get_featurestore_empty_call_async(): assert args[0] == featurestore_service.GetFeaturestoreRequest() +@pytest.mark.asyncio +async def test_get_featurestore_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_featurestore + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_featurestore + ] = mock_object + + request = {} + await client.get_featurestore(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_featurestore_async( transport: str = "grpc_asyncio", @@ -1880,6 +2068,9 @@ def test_list_featurestores_empty_call(): with mock.patch.object( type(client.transport.list_featurestores), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_featurestores() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1908,6 +2099,9 @@ def test_list_featurestores_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_featurestores), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_featurestores(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1919,6 +2113,45 @@ def test_list_featurestores_non_empty_request_with_auto_populated_field(): ) +def test_list_featurestores_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_featurestores in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_featurestores + ] = mock_rpc + request = {} + client.list_featurestores(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_featurestores(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_featurestores_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1944,6 +2177,52 @@ async def test_list_featurestores_empty_call_async(): assert args[0] == featurestore_service.ListFeaturestoresRequest() +@pytest.mark.asyncio +async def test_list_featurestores_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_featurestores + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_featurestores + ] = mock_object + + request = {} + await client.list_featurestores(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_featurestores(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_featurestores_async( transport: str = "grpc_asyncio", @@ -2382,6 +2661,9 @@ def test_update_featurestore_empty_call(): with mock.patch.object( type(client.transport.update_featurestore), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_featurestore() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2405,12 +2687,58 @@ def test_update_featurestore_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.update_featurestore), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_featurestore(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.UpdateFeaturestoreRequest() +def test_update_featurestore_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_featurestore in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.update_featurestore + ] = mock_rpc + request = {} + client.update_featurestore(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_featurestore_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2434,6 +2762,56 @@ async def test_update_featurestore_empty_call_async(): assert args[0] == featurestore_service.UpdateFeaturestoreRequest() +@pytest.mark.asyncio +async def test_update_featurestore_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_featurestore + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_featurestore + ] = mock_object + + request = {} + await client.update_featurestore(request) + + # Establish that the underlying gRPC stub method was 
called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_featurestore_async( transport: str = "grpc_asyncio", @@ -2681,6 +3059,9 @@ def test_delete_featurestore_empty_call(): with mock.patch.object( type(client.transport.delete_featurestore), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_featurestore() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2706,6 +3087,9 @@ def test_delete_featurestore_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_featurestore), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_featurestore(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2714,6 +3098,49 @@ def test_delete_featurestore_non_empty_request_with_auto_populated_field(): ) +def test_delete_featurestore_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_featurestore in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_featurestore + ] = mock_rpc + request = {} + client.delete_featurestore(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_featurestore_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2737,6 +3164,56 @@ async def test_delete_featurestore_empty_call_async(): assert args[0] == featurestore_service.DeleteFeaturestoreRequest() +@pytest.mark.asyncio +async def test_delete_featurestore_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_featurestore + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_featurestore + ] = mock_object + + request = {} + await client.delete_featurestore(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_featurestore_async( transport: str = "grpc_asyncio", @@ -2984,6 +3461,9 @@ def test_create_entity_type_empty_call(): with mock.patch.object( type(client.transport.create_entity_type), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_entity_type() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3010,6 +3490,9 @@ def test_create_entity_type_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_entity_type), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_entity_type(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3019,6 +3502,49 @@ def test_create_entity_type_non_empty_request_with_auto_populated_field(): ) +def test_create_entity_type_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_entity_type in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_entity_type + ] = mock_rpc + request = {} + client.create_entity_type(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_entity_type_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3042,6 +3568,56 @@ async def test_create_entity_type_empty_call_async(): assert args[0] == featurestore_service.CreateEntityTypeRequest() +@pytest.mark.asyncio +async def test_create_entity_type_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_entity_type + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_entity_type + ] = mock_object + + request = {} + await client.create_entity_type(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_entity_type_async( transport: str = "grpc_asyncio", @@ -3304,6 +3880,9 @@ def test_get_entity_type_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_entity_type), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_entity_type() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3327,6 +3906,9 @@ def test_get_entity_type_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_entity_type), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_entity_type(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3335,14 +3917,49 @@ def test_get_entity_type_non_empty_request_with_auto_populated_field(): ) -@pytest.mark.asyncio -async def test_get_entity_type_empty_call_async(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. 
- client = FeaturestoreServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc_asyncio", - ) +def test_get_entity_type_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_entity_type in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_entity_type] = mock_rpc + request = {} + client.get_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_entity_type_empty_call_async(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.get_entity_type), "__call__") as call: @@ -3361,6 +3978,52 @@ async def test_get_entity_type_empty_call_async(): assert args[0] == featurestore_service.GetEntityTypeRequest() +@pytest.mark.asyncio +async def test_get_entity_type_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_entity_type + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_entity_type + ] = mock_object + + request = {} + await client.get_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.get_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_entity_type_async( transport: str = "grpc_asyncio", @@ -3600,6 +4263,9 @@ def test_list_entity_types_empty_call(): with mock.patch.object( type(client.transport.list_entity_types), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_entity_types() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3628,6 +4294,9 @@ def test_list_entity_types_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_entity_types), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_entity_types(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3639,6 +4308,43 @@ def test_list_entity_types_non_empty_request_with_auto_populated_field(): ) +def test_list_entity_types_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_entity_types in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_entity_types + ] = mock_rpc + request = {} + client.list_entity_types(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_entity_types(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_entity_types_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3664,6 +4370,52 @@ async def test_list_entity_types_empty_call_async(): assert args[0] == featurestore_service.ListEntityTypesRequest() +@pytest.mark.asyncio +async def test_list_entity_types_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_entity_types + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_entity_types + ] = mock_object + + request = {} + await client.list_entity_types(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_entity_types(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_entity_types_async( transport: str = "grpc_asyncio", @@ -4111,6 +4863,9 @@ def test_update_entity_type_empty_call(): with mock.patch.object( type(client.transport.update_entity_type), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_entity_type() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4134,12 +4889,54 @@ def test_update_entity_type_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.update_entity_type), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_entity_type(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.UpdateEntityTypeRequest() +def test_update_entity_type_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_entity_type in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.update_entity_type + ] = mock_rpc + request = {} + client.update_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_entity_type_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4168,6 +4965,52 @@ async def test_update_entity_type_empty_call_async(): assert args[0] == featurestore_service.UpdateEntityTypeRequest() +@pytest.mark.asyncio +async def test_update_entity_type_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_entity_type + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_entity_type + ] = mock_object + + request = {} + await client.update_entity_type(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_entity_type_async( transport: str = "grpc_asyncio", @@ -4424,6 +5267,9 @@ def test_delete_entity_type_empty_call(): with mock.patch.object( type(client.transport.delete_entity_type), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_entity_type() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4449,6 +5295,9 @@ def test_delete_entity_type_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_entity_type), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_entity_type(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4457,6 +5306,49 @@ def test_delete_entity_type_non_empty_request_with_auto_populated_field(): ) +def test_delete_entity_type_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_entity_type in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_entity_type + ] = mock_rpc + request = {} + client.delete_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_entity_type_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4480,6 +5372,56 @@ async def test_delete_entity_type_empty_call_async(): assert args[0] == featurestore_service.DeleteEntityTypeRequest() +@pytest.mark.asyncio +async def test_delete_entity_type_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_entity_type + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_entity_type + ] = mock_object + + request = {} + await client.delete_entity_type(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_entity_type_async( transport: str = "grpc_asyncio", @@ -4723,6 +5665,9 @@ def test_create_feature_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_feature() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4747,6 +5692,9 @@ def test_create_feature_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_feature(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4756,6 +5704,45 @@ def test_create_feature_non_empty_request_with_auto_populated_field(): ) +def test_create_feature_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_feature] = mock_rpc + request = {} + client.create_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_feature_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4777,6 +5764,56 @@ async def test_create_feature_empty_call_async(): assert args[0] == featurestore_service.CreateFeatureRequest() +@pytest.mark.asyncio +async def test_create_feature_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_feature + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_feature + ] = mock_object + + request = {} + await client.create_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_feature_async( transport: str = "grpc_asyncio", @@ -5024,6 +6061,9 @@ def test_batch_create_features_empty_call(): with mock.patch.object( type(client.transport.batch_create_features), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_create_features() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5049,6 +6089,9 @@ def test_batch_create_features_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.batch_create_features), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.batch_create_features(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5057,6 +6100,50 @@ def test_batch_create_features_non_empty_request_with_auto_populated_field(): ) +def test_batch_create_features_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_create_features + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_create_features + ] = mock_rpc + request = {} + client.batch_create_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.batch_create_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_batch_create_features_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5080,6 +6167,56 @@ async def test_batch_create_features_empty_call_async(): assert args[0] == featurestore_service.BatchCreateFeaturesRequest() +@pytest.mark.asyncio +async def test_batch_create_features_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_create_features + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_create_features + ] = mock_object + + request = {} + await client.batch_create_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.batch_create_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_batch_create_features_async( transport: str = "grpc_asyncio", @@ -5338,6 +6475,9 @@ def test_get_feature_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_feature() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5361,6 +6501,9 @@ def test_get_feature_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_feature(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5369,6 +6512,41 @@ def test_get_feature_non_empty_request_with_auto_populated_field(): ) +def test_get_feature_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_feature] = mock_rpc + request = {} + client.get_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5398,6 +6576,52 @@ async def test_get_feature_empty_call_async(): assert args[0] == featurestore_service.GetFeatureRequest() +@pytest.mark.asyncio +async def test_get_feature_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_feature + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_feature + ] = mock_object + + request = {} + await client.get_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_feature_async( transport: str = "grpc_asyncio", request_type=featurestore_service.GetFeatureRequest @@ -5634,6 +6858,9 @@ def test_list_features_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_features), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_features() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5660,6 +6887,9 @@ def test_list_features_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_features), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_features(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5671,6 +6901,41 @@ def test_list_features_non_empty_request_with_auto_populated_field(): ) +def test_list_features_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_features in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_features] = mock_rpc + request = {} + client.list_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_features_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5694,6 +6959,52 @@ async def test_list_features_empty_call_async(): assert args[0] == featurestore_service.ListFeaturesRequest() +@pytest.mark.asyncio +async def test_list_features_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_features + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_features + ] = mock_object + + request = {} + await client.list_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_features_async( transport: str = "grpc_asyncio", @@ -6125,6 +7436,9 @@ def test_update_feature_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_feature() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6146,12 +7460,50 @@ def test_update_feature_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_feature(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == featurestore_service.UpdateFeatureRequest() +def test_update_feature_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[client._transport.update_feature] = mock_rpc + request = {} + client.update_feature(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_feature_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6181,6 +7533,52 @@ async def test_update_feature_empty_call_async(): assert args[0] == featurestore_service.UpdateFeatureRequest() +@pytest.mark.asyncio +async def test_update_feature_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_feature + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_feature + ] = mock_object + + request = {} + await client.update_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_feature_async( transport: str = "grpc_asyncio", @@ -6425,6 +7823,9 @@ def test_delete_feature_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_feature() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6448,6 +7849,9 @@ def test_delete_feature_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_feature), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_feature(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6456,6 +7860,45 @@ def test_delete_feature_non_empty_request_with_auto_populated_field(): ) +def test_delete_feature_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_feature] = mock_rpc + request = {} + client.delete_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6477,6 +7920,56 @@ async def test_delete_feature_empty_call_async(): assert args[0] == featurestore_service.DeleteFeatureRequest() +@pytest.mark.asyncio +async def test_delete_feature_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_feature + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_feature + ] = mock_object + + request = {} + await client.delete_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_async( transport: str = "grpc_asyncio", @@ -6704,6 +8197,9 @@ def test_import_feature_values_empty_call(): with mock.patch.object( type(client.transport.import_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.import_feature_values() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6731,6 +8227,9 @@ def test_import_feature_values_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.import_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.import_feature_values(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6741,6 +8240,50 @@ def test_import_feature_values_non_empty_request_with_auto_populated_field(): ) +def test_import_feature_values_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.import_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.import_feature_values + ] = mock_rpc + request = {} + client.import_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.import_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_import_feature_values_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6764,6 +8307,56 @@ async def test_import_feature_values_empty_call_async(): assert args[0] == featurestore_service.ImportFeatureValuesRequest() +@pytest.mark.asyncio +async def test_import_feature_values_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.import_feature_values + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.import_feature_values + ] = mock_object + + request = {} + await client.import_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.import_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_import_feature_values_async( transport: str = "grpc_asyncio", @@ -7001,6 +8594,9 @@ def test_batch_read_feature_values_empty_call(): with mock.patch.object( type(client.transport.batch_read_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_read_feature_values() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7026,6 +8622,9 @@ def test_batch_read_feature_values_non_empty_request_with_auto_populated_field() with mock.patch.object( type(client.transport.batch_read_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.batch_read_feature_values(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7034,6 +8633,50 @@ def test_batch_read_feature_values_non_empty_request_with_auto_populated_field() ) +def test_batch_read_feature_values_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_read_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_read_feature_values + ] = mock_rpc + request = {} + client.batch_read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.batch_read_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_batch_read_feature_values_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7057,6 +8700,56 @@ async def test_batch_read_feature_values_empty_call_async(): assert args[0] == featurestore_service.BatchReadFeatureValuesRequest() +@pytest.mark.asyncio +async def test_batch_read_feature_values_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_read_feature_values + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_read_feature_values + ] = mock_object + + request = {} + await client.batch_read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.batch_read_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_batch_read_feature_values_async( transport: str = "grpc_asyncio", @@ -7294,6 +8987,9 @@ def test_export_feature_values_empty_call(): with mock.patch.object( type(client.transport.export_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.export_feature_values() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7319,6 +9015,9 @@ def test_export_feature_values_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.export_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.export_feature_values(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7327,6 +9026,50 @@ def test_export_feature_values_non_empty_request_with_auto_populated_field(): ) +def test_export_feature_values_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.export_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.export_feature_values + ] = mock_rpc + request = {} + client.export_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.export_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_export_feature_values_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7350,6 +9093,56 @@ async def test_export_feature_values_empty_call_async(): assert args[0] == featurestore_service.ExportFeatureValuesRequest() +@pytest.mark.asyncio +async def test_export_feature_values_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.export_feature_values + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.export_feature_values + ] = mock_object + + request = {} + await client.export_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.export_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_export_feature_values_async( transport: str = "grpc_asyncio", @@ -7587,6 +9380,9 @@ def test_delete_feature_values_empty_call(): with mock.patch.object( type(client.transport.delete_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_feature_values() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7612,6 +9408,9 @@ def test_delete_feature_values_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_feature_values), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_feature_values(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7620,6 +9419,50 @@ def test_delete_feature_values_non_empty_request_with_auto_populated_field(): ) +def test_delete_feature_values_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_feature_values + ] = mock_rpc + request = {} + client.delete_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_values_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7643,6 +9486,56 @@ async def test_delete_feature_values_empty_call_async(): assert args[0] == featurestore_service.DeleteFeatureValuesRequest() +@pytest.mark.asyncio +async def test_delete_feature_values_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_feature_values + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_feature_values + ] = mock_object + + request = {} + await client.delete_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_feature_values_async( transport: str = "grpc_asyncio", @@ -7879,6 +9772,9 @@ def test_search_features_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.search_features), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.search_features() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7904,6 +9800,9 @@ def test_search_features_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.search_features), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.search_features(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7914,6 +9813,41 @@ def test_search_features_non_empty_request_with_auto_populated_field(): ) +def test_search_features_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.search_features in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.search_features] = mock_rpc + request = {} + client.search_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_search_features_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7937,6 +9871,52 @@ async def test_search_features_empty_call_async(): assert args[0] == featurestore_service.SearchFeaturesRequest() +@pytest.mark.asyncio +async def test_search_features_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FeaturestoreServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.search_features + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.search_features + ] = mock_object + + request = {} + await client.search_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.search_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_search_features_async( transport: str = "grpc_asyncio", @@ -8442,6 +10422,50 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_featurestore_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_featurestore in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_featurestore + ] = mock_rpc + + request = {} + client.create_featurestore(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_featurestore_rest_required_fields( request_type=featurestore_service.CreateFeaturestoreRequest, ): @@ -8740,6 +10764,44 @@ def test_get_featurestore_rest(request_type): assert response.online_storage_ttl_days == 2460 +def test_get_featurestore_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_featurestore in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_featurestore + ] = mock_rpc + + request = {} + client.get_featurestore(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_featurestore_rest_required_fields( request_type=featurestore_service.GetFeaturestoreRequest, ): @@ -9009,6 +11071,46 @@ def test_list_featurestores_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_featurestores_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_featurestores in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_featurestores + ] = mock_rpc + + request = {} + client.list_featurestores(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_featurestores(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_featurestores_rest_required_fields( request_type=featurestore_service.ListFeaturestoresRequest, ): @@ -9444,6 +11546,50 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_featurestore_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_featurestore in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_featurestore + ] = mock_rpc + + request = {} + client.update_featurestore(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_featurestore_rest_required_fields( request_type=featurestore_service.UpdateFeaturestoreRequest, ): @@ -9711,6 +11857,50 @@ def test_delete_featurestore_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_featurestore_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_featurestore in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_featurestore + ] = mock_rpc + + request = {} + client.delete_featurestore(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_featurestore(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_featurestore_rest_required_fields( request_type=featurestore_service.DeleteFeaturestoreRequest, ): @@ -10065,6 +12255,50 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_entity_type_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_entity_type in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_entity_type + ] = mock_rpc + + request = {} + client.create_entity_type(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_entity_type_rest_required_fields( request_type=featurestore_service.CreateEntityTypeRequest, ): @@ -10356,16 +12590,52 @@ def test_get_entity_type_rest(request_type): return_value = entity_type.EntityType.pb(return_value) json_return_value = json_format.MessageToJson(return_value) - response_value._content = json_return_value.encode("UTF-8") - req.return_value = response_value - response = client.get_entity_type(request) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.get_entity_type(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, entity_type.EntityType) + assert response.name == "name_value" + assert response.description == "description_value" + assert response.etag == "etag_value" + assert response.offline_storage_ttl_days == 2554 + + +def test_get_entity_type_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_entity_type in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute 
client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_entity_type] = mock_rpc + + request = {} + client.get_entity_type(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_entity_type(request) - # Establish that the response is the type that we expect. - assert isinstance(response, entity_type.EntityType) - assert response.name == "name_value" - assert response.description == "description_value" - assert response.etag == "etag_value" - assert response.offline_storage_ttl_days == 2554 + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 def test_get_entity_type_rest_required_fields( @@ -10641,6 +12911,44 @@ def test_list_entity_types_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_entity_types_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_entity_types in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_entity_types + ] = mock_rpc + + request = {} + client.list_entity_types(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_entity_types(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_entity_types_rest_required_fields( request_type=featurestore_service.ListEntityTypesRequest, ): @@ -11091,6 +13399,46 @@ def get_message_fields(field): assert response.offline_storage_ttl_days == 2554 +def test_update_entity_type_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_entity_type in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_entity_type + ] = mock_rpc + + request = {} + client.update_entity_type(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_entity_type_rest_required_fields( request_type=featurestore_service.UpdateEntityTypeRequest, ): @@ -11363,6 +13711,50 @@ def test_delete_entity_type_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_entity_type_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_entity_type in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_entity_type + ] = mock_rpc + + request = {} + client.delete_entity_type(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_entity_type(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_entity_type_rest_required_fields( request_type=featurestore_service.DeleteEntityTypeRequest, ): @@ -11734,6 +14126,46 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_feature_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_feature] = mock_rpc + + request = {} + client.create_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_feature_rest_required_fields( request_type=featurestore_service.CreateFeatureRequest, ): @@ -12027,6 +14459,51 @@ def test_batch_create_features_rest(request_type): assert response.operation.name == "operations/spam" +def test_batch_create_features_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_create_features + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_create_features + ] = mock_rpc + + request = {} + client.batch_create_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.batch_create_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_batch_create_features_rest_required_fields( request_type=featurestore_service.BatchCreateFeaturesRequest, ): @@ -12321,6 +14798,42 @@ def test_get_feature_rest(request_type): assert response.point_of_contact == "point_of_contact_value" +def test_get_feature_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_feature] = mock_rpc + + request = {} + client.get_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_feature_rest_required_fields( request_type=featurestore_service.GetFeatureRequest, ): @@ -12592,6 +15105,42 @@ def test_list_features_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_features_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_features in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_features] = mock_rpc + + request = {} + client.list_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_features_rest_required_fields( request_type=featurestore_service.ListFeaturesRequest, ): @@ -13063,6 +15612,42 @@ def get_message_fields(field): assert response.point_of_contact == "point_of_contact_value" +def test_update_feature_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_feature] = mock_rpc + + request = {} + client.update_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_feature_rest_required_fields( request_type=featurestore_service.UpdateFeatureRequest, ): @@ -13333,6 +15918,46 @@ def test_delete_feature_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_feature_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_feature in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_feature] = mock_rpc + + request = {} + client.delete_feature(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_feature_rest_required_fields( request_type=featurestore_service.DeleteFeatureRequest, ): @@ -13598,6 +16223,51 @@ def test_import_feature_values_rest(request_type): assert response.operation.name == "operations/spam" +def test_import_feature_values_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.import_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.import_feature_values + ] = mock_rpc + + request = {} + client.import_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.import_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_import_feature_values_rest_required_fields( request_type=featurestore_service.ImportFeatureValuesRequest, ): @@ -13873,6 +16543,51 @@ def test_batch_read_feature_values_rest(request_type): assert response.operation.name == "operations/spam" +def test_batch_read_feature_values_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_read_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_read_feature_values + ] = mock_rpc + + request = {} + client.batch_read_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.batch_read_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_batch_read_feature_values_rest_required_fields( request_type=featurestore_service.BatchReadFeatureValuesRequest, ): @@ -14149,6 +16864,51 @@ def test_export_feature_values_rest(request_type): assert response.operation.name == "operations/spam" +def test_export_feature_values_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.export_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.export_feature_values + ] = mock_rpc + + request = {} + client.export_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.export_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_export_feature_values_rest_required_fields( request_type=featurestore_service.ExportFeatureValuesRequest, ): @@ -14425,6 +17185,51 @@ def test_delete_feature_values_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_feature_values_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_feature_values + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_feature_values + ] = mock_rpc + + request = {} + client.delete_feature_values(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_feature_values(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_feature_values_rest_required_fields( request_type=featurestore_service.DeleteFeatureValuesRequest, ): @@ -14695,6 +17500,42 @@ def test_search_features_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_search_features_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FeaturestoreServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.search_features in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.search_features] = mock_rpc + + request = {} + client.search_features(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_features(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_search_features_rest_required_fields( request_type=featurestore_service.SearchFeaturesRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py index e74d482eec..c1c16e7289 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py @@ -1257,6 +1257,9 @@ def test_create_index_endpoint_empty_call(): with mock.patch.object( type(client.transport.create_index_endpoint), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_index_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1282,6 +1285,9 @@ def test_create_index_endpoint_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_index_endpoint), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_index_endpoint(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1290,6 +1296,50 @@ def test_create_index_endpoint_non_empty_request_with_auto_populated_field(): ) +def test_create_index_endpoint_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_index_endpoint + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_index_endpoint + ] = mock_rpc + request = {} + client.create_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_index_endpoint_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1313,6 +1363,56 @@ async def test_create_index_endpoint_empty_call_async(): assert args[0] == index_endpoint_service.CreateIndexEndpointRequest() +@pytest.mark.asyncio +async def test_create_index_endpoint_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexEndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_index_endpoint + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_index_endpoint + ] = mock_object + + request = {} + await client.create_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_index_endpoint_async( transport: str = "grpc_asyncio", @@ -1577,6 +1677,9 @@ def test_get_index_endpoint_empty_call(): with mock.patch.object( type(client.transport.get_index_endpoint), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_index_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1602,6 +1705,9 @@ def test_get_index_endpoint_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_index_endpoint), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_index_endpoint(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1610,6 +1716,45 @@ def test_get_index_endpoint_non_empty_request_with_auto_populated_field(): ) +def test_get_index_endpoint_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_index_endpoint in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_index_endpoint + ] = mock_rpc + request = {} + client.get_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_index_endpoint_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1642,6 +1787,52 @@ async def test_get_index_endpoint_empty_call_async(): assert args[0] == index_endpoint_service.GetIndexEndpointRequest() +@pytest.mark.asyncio +async def test_get_index_endpoint_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexEndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_index_endpoint + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_index_endpoint + ] = mock_object + + request = {} + await client.get_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_index_endpoint_async( transport: str = "grpc_asyncio", @@ -1899,6 +2090,9 @@ def test_list_index_endpoints_empty_call(): with mock.patch.object( type(client.transport.list_index_endpoints), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_index_endpoints() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1926,6 +2120,9 @@ def test_list_index_endpoints_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_index_endpoints), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_index_endpoints(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1936,6 +2133,45 @@ def test_list_index_endpoints_non_empty_request_with_auto_populated_field(): ) +def test_list_index_endpoints_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_index_endpoints in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_index_endpoints + ] = mock_rpc + request = {} + client.list_index_endpoints(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_index_endpoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_index_endpoints_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1961,6 +2197,52 @@ async def test_list_index_endpoints_empty_call_async(): assert args[0] == index_endpoint_service.ListIndexEndpointsRequest() +@pytest.mark.asyncio +async def test_list_index_endpoints_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexEndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_index_endpoints + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_index_endpoints + ] = mock_object + + request = {} + await client.list_index_endpoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_index_endpoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_index_endpoints_async( transport: str = "grpc_asyncio", @@ -2416,6 +2698,9 @@ def test_update_index_endpoint_empty_call(): with mock.patch.object( type(client.transport.update_index_endpoint), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_index_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2439,12 +2724,55 @@ def test_update_index_endpoint_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.update_index_endpoint), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_index_endpoint(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == index_endpoint_service.UpdateIndexEndpointRequest() +def test_update_index_endpoint_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_index_endpoint + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.update_index_endpoint + ] = mock_rpc + request = {} + client.update_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_index_endpoint_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2477,6 +2805,52 @@ async def test_update_index_endpoint_empty_call_async(): assert args[0] == index_endpoint_service.UpdateIndexEndpointRequest() +@pytest.mark.asyncio +async def test_update_index_endpoint_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexEndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_index_endpoint + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_index_endpoint + ] = mock_object + + request = {} + await client.update_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_index_endpoint_async( transport: str = "grpc_asyncio", @@ -2741,6 +3115,9 @@ def test_delete_index_endpoint_empty_call(): with mock.patch.object( type(client.transport.delete_index_endpoint), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_index_endpoint() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2766,6 +3143,9 @@ def test_delete_index_endpoint_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_index_endpoint), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_index_endpoint(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2774,6 +3154,50 @@ def test_delete_index_endpoint_non_empty_request_with_auto_populated_field(): ) +def test_delete_index_endpoint_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_index_endpoint + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_index_endpoint + ] = mock_rpc + request = {} + client.delete_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_index_endpoint_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2797,6 +3221,56 @@ async def test_delete_index_endpoint_empty_call_async(): assert args[0] == index_endpoint_service.DeleteIndexEndpointRequest() +@pytest.mark.asyncio +async def test_delete_index_endpoint_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexEndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_index_endpoint + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_index_endpoint + ] = mock_object + + request = {} + await client.delete_index_endpoint(request) + + # Establish that the underlying 
gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_index_endpoint_async( transport: str = "grpc_asyncio", @@ -3030,6 +3504,9 @@ def test_deploy_index_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.deploy_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.deploy_index() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3053,6 +3530,9 @@ def test_deploy_index_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.deploy_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.deploy_index(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3061,6 +3541,45 @@ def test_deploy_index_non_empty_request_with_auto_populated_field(): ) +def test_deploy_index_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.deploy_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.deploy_index] = mock_rpc + request = {} + client.deploy_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.deploy_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_deploy_index_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3082,6 +3601,56 @@ async def test_deploy_index_empty_call_async(): assert args[0] == index_endpoint_service.DeployIndexRequest() +@pytest.mark.asyncio +async def test_deploy_index_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexEndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.deploy_index + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.deploy_index + ] = mock_object + + request = {} + await client.deploy_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.deploy_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_deploy_index_async( transport: str = "grpc_asyncio", @@ -3315,6 +3884,9 @@ def test_undeploy_index_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.undeploy_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.undeploy_index() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3339,6 +3911,9 @@ def test_undeploy_index_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.undeploy_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.undeploy_index(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3348,6 +3923,45 @@ def test_undeploy_index_non_empty_request_with_auto_populated_field(): ) +def test_undeploy_index_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.undeploy_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.undeploy_index] = mock_rpc + request = {} + client.undeploy_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.undeploy_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_undeploy_index_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3369,6 +3983,56 @@ async def test_undeploy_index_empty_call_async(): assert args[0] == index_endpoint_service.UndeployIndexRequest() +@pytest.mark.asyncio +async def test_undeploy_index_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexEndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.undeploy_index + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.undeploy_index + ] = mock_object + + request = {} + await client.undeploy_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.undeploy_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_undeploy_index_async( transport: str = "grpc_asyncio", @@ -3606,6 +4270,9 @@ def test_mutate_deployed_index_empty_call(): with mock.patch.object( type(client.transport.mutate_deployed_index), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.mutate_deployed_index() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3631,6 +4298,9 @@ def test_mutate_deployed_index_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.mutate_deployed_index), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.mutate_deployed_index(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3639,6 +4309,50 @@ def test_mutate_deployed_index_non_empty_request_with_auto_populated_field(): ) +def test_mutate_deployed_index_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.mutate_deployed_index + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.mutate_deployed_index + ] = mock_rpc + request = {} + client.mutate_deployed_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.mutate_deployed_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_mutate_deployed_index_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3662,6 +4376,56 @@ async def test_mutate_deployed_index_empty_call_async(): assert args[0] == index_endpoint_service.MutateDeployedIndexRequest() +@pytest.mark.asyncio +async def test_mutate_deployed_index_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexEndpointServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.mutate_deployed_index + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.mutate_deployed_index + ] = mock_object + + request = {} + await client.mutate_deployed_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.mutate_deployed_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_mutate_deployed_index_async( transport: str = "grpc_asyncio", @@ -4039,6 +4803,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_index_endpoint_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_index_endpoint + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_index_endpoint + ] = mock_rpc + + request = {} + client.create_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_index_endpoint_rest_required_fields( request_type=index_endpoint_service.CreateIndexEndpointRequest, ): @@ -4329,6 +5138,46 @@ def test_get_index_endpoint_rest(request_type): assert response.public_endpoint_domain_name == "public_endpoint_domain_name_value" +def test_get_index_endpoint_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_index_endpoint in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_index_endpoint + ] = mock_rpc + + request = {} + client.get_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_index_endpoint_rest_required_fields( request_type=index_endpoint_service.GetIndexEndpointRequest, ): @@ -4600,6 +5449,46 @@ def test_list_index_endpoints_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_index_endpoints_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_index_endpoints in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_index_endpoints + ] = mock_rpc + + request = {} + client.list_index_endpoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_index_endpoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_index_endpoints_rest_required_fields( request_type=index_endpoint_service.ListIndexEndpointsRequest, ): @@ -5111,6 +6000,47 @@ def get_message_fields(field): assert response.public_endpoint_domain_name == "public_endpoint_domain_name_value" +def test_update_index_endpoint_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_index_endpoint + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_index_endpoint + ] = mock_rpc + + request = {} + client.update_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_index_endpoint_rest_required_fields( request_type=index_endpoint_service.UpdateIndexEndpointRequest, ): @@ -5390,6 +6320,51 @@ def test_delete_index_endpoint_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_index_endpoint_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_index_endpoint + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_index_endpoint + ] = mock_rpc + + request = {} + client.delete_index_endpoint(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_index_endpoint(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_index_endpoint_rest_required_fields( request_type=index_endpoint_service.DeleteIndexEndpointRequest, ): @@ -5654,6 +6629,46 @@ def test_deploy_index_rest(request_type): assert response.operation.name == "operations/spam" +def test_deploy_index_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.deploy_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.deploy_index] = mock_rpc + + request = {} + client.deploy_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.deploy_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_deploy_index_rest_required_fields( request_type=index_endpoint_service.DeployIndexRequest, ): @@ -5930,6 +6945,46 @@ def test_undeploy_index_rest(request_type): assert response.operation.name == "operations/spam" +def test_undeploy_index_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.undeploy_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.undeploy_index] = mock_rpc + + request = {} + client.undeploy_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.undeploy_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_undeploy_index_rest_required_fields( request_type=index_endpoint_service.UndeployIndexRequest, ): @@ -6323,6 +7378,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_mutate_deployed_index_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexEndpointServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.mutate_deployed_index + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.mutate_deployed_index + ] = mock_rpc + + request = {} + client.mutate_deployed_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.mutate_deployed_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_mutate_deployed_index_rest_required_fields( request_type=index_endpoint_service.MutateDeployedIndexRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py index ed4baa3a43..ef7c84648b 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py @@ -1159,6 +1159,9 @@ def test_create_index_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_index() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1182,6 +1185,9 @@ def test_create_index_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_index(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1190,6 +1196,45 @@ def test_create_index_non_empty_request_with_auto_populated_field(): ) +def test_create_index_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_index] = mock_rpc + request = {} + client.create_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_index_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1211,6 +1256,56 @@ async def test_create_index_empty_call_async(): assert args[0] == index_service.CreateIndexRequest() +@pytest.mark.asyncio +async def test_create_index_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_index + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_index + ] = mock_object + + request = {} + await client.create_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_index_async( transport: str = "grpc_asyncio", request_type=index_service.CreateIndexRequest @@ -1456,6 +1551,9 @@ def test_get_index_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_index() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1479,6 +1577,9 @@ def test_get_index_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_index(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1487,6 +1588,41 @@ def test_get_index_non_empty_request_with_auto_populated_field(): ) +def test_get_index_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_index] = mock_rpc + request = {} + client.get_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_index_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1515,6 +1651,50 @@ async def test_get_index_empty_call_async(): assert args[0] == index_service.GetIndexRequest() +@pytest.mark.asyncio +async def test_get_index_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_index + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_index + ] = mock_object + + request = {} + await client.get_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_index_async( transport: str = "grpc_asyncio", request_type=index_service.GetIndexRequest @@ -1749,6 +1929,9 @@ def test_list_indexes_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_indexes() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1774,6 +1957,9 @@ def test_list_indexes_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_indexes), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_indexes(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1784,6 +1970,41 @@ def test_list_indexes_non_empty_request_with_auto_populated_field(): ) +def test_list_indexes_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_indexes in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_indexes] = mock_rpc + request = {} + client.list_indexes(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_indexes(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_indexes_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1807,6 +2028,52 @@ async def test_list_indexes_empty_call_async(): assert args[0] == index_service.ListIndexesRequest() +@pytest.mark.asyncio +async def test_list_indexes_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_indexes + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_indexes + ] = mock_object + + request = {} + await client.list_indexes(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_indexes(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_indexes_async( transport: str = "grpc_asyncio", request_type=index_service.ListIndexesRequest @@ -2222,6 +2489,9 @@ def test_update_index_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_index() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2243,12 +2513,54 @@ def test_update_index_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_index(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == index_service.UpdateIndexRequest() +def test_update_index_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[client._transport.update_index] = mock_rpc + request = {} + client.update_index(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_index_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2270,6 +2582,56 @@ async def test_update_index_empty_call_async(): assert args[0] == index_service.UpdateIndexRequest() +@pytest.mark.asyncio +async def test_update_index_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_index + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_index + ] = mock_object + + request = {} + await client.update_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_index_async( transport: str = "grpc_asyncio", request_type=index_service.UpdateIndexRequest @@ -2502,6 +2864,9 @@ def test_delete_index_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_index() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2525,6 +2890,9 @@ def test_delete_index_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_index), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_index(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2533,6 +2901,45 @@ def test_delete_index_non_empty_request_with_auto_populated_field(): ) +def test_delete_index_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_index] = mock_rpc + request = {} + client.delete_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_index_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2554,6 +2961,56 @@ async def test_delete_index_empty_call_async(): assert args[0] == index_service.DeleteIndexRequest() +@pytest.mark.asyncio +async def test_delete_index_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_index + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_index + ] = mock_object + + request = {} + await client.delete_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_index_async( transport: str = "grpc_asyncio", request_type=index_service.DeleteIndexRequest @@ -2780,6 +3237,9 @@ def test_upsert_datapoints_empty_call(): with mock.patch.object( type(client.transport.upsert_datapoints), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.upsert_datapoints() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2805,6 +3265,9 @@ def test_upsert_datapoints_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.upsert_datapoints), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.upsert_datapoints(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2813,6 +3276,43 @@ def test_upsert_datapoints_non_empty_request_with_auto_populated_field(): ) +def test_upsert_datapoints_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.upsert_datapoints in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.upsert_datapoints + ] = mock_rpc + request = {} + client.upsert_datapoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.upsert_datapoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_upsert_datapoints_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2836,6 +3336,52 @@ async def test_upsert_datapoints_empty_call_async(): assert args[0] == index_service.UpsertDatapointsRequest() +@pytest.mark.asyncio +async def test_upsert_datapoints_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.upsert_datapoints + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.upsert_datapoints + ] = mock_object + + request = {} + await client.upsert_datapoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.upsert_datapoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_upsert_datapoints_async( transport: str = "grpc_asyncio", request_type=index_service.UpsertDatapointsRequest @@ -2986,6 +3532,9 @@ def test_remove_datapoints_empty_call(): with mock.patch.object( type(client.transport.remove_datapoints), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.remove_datapoints() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3011,6 +3560,9 @@ def test_remove_datapoints_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.remove_datapoints), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.remove_datapoints(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3019,6 +3571,43 @@ def test_remove_datapoints_non_empty_request_with_auto_populated_field(): ) +def test_remove_datapoints_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.remove_datapoints in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.remove_datapoints + ] = mock_rpc + request = {} + client.remove_datapoints(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.remove_datapoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_remove_datapoints_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3042,6 +3631,52 @@ async def test_remove_datapoints_empty_call_async(): assert args[0] == index_service.RemoveDatapointsRequest() +@pytest.mark.asyncio +async def test_remove_datapoints_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = IndexServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.remove_datapoints + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.remove_datapoints + ] = mock_object + + request = {} + await client.remove_datapoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.remove_datapoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_remove_datapoints_async( transport: str = "grpc_asyncio", request_type=index_service.RemoveDatapointsRequest @@ -3275,6 +3910,46 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_index_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_index] = mock_rpc + + request = {} + client.create_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_index_rest_required_fields( request_type=index_service.CreateIndexRequest, ): @@ -3560,6 +4235,42 @@ def test_get_index_rest(request_type): assert response.index_update_method == index.Index.IndexUpdateMethod.BATCH_UPDATE +def test_get_index_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_index] = mock_rpc + + request = {} + client.get_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_index_rest_required_fields(request_type=index_service.GetIndexRequest): transport_class = transports.IndexServiceRestTransport @@ -3821,6 +4532,42 @@ def test_list_indexes_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_indexes_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_indexes in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_indexes] = mock_rpc + + request = {} + client.list_indexes(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_indexes(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_indexes_rest_required_fields( request_type=index_service.ListIndexesRequest, ): @@ -4254,6 +5001,46 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_index_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_index] = mock_rpc + + request = {} + client.update_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_index_rest_required_fields( request_type=index_service.UpdateIndexRequest, ): @@ -4517,6 +5304,46 @@ def test_delete_index_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_index_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_index in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_index] = mock_rpc + + request = {} + client.delete_index(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_index(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_index_rest_required_fields( request_type=index_service.DeleteIndexRequest, ): @@ -4778,6 +5605,44 @@ def test_upsert_datapoints_rest(request_type): assert isinstance(response, index_service.UpsertDatapointsResponse) +def test_upsert_datapoints_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.upsert_datapoints in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.upsert_datapoints + ] = mock_rpc + + request = {} + client.upsert_datapoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.upsert_datapoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_upsert_datapoints_rest_required_fields( request_type=index_service.UpsertDatapointsRequest, ): @@ -4986,6 +5851,44 @@ def test_remove_datapoints_rest(request_type): assert isinstance(response, index_service.RemoveDatapointsResponse) +def test_remove_datapoints_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = IndexServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.remove_datapoints in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.remove_datapoints + ] = mock_rpc + + request = {} + client.remove_datapoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.remove_datapoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_remove_datapoints_rest_required_fields( request_type=index_service.RemoveDatapointsRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py index 5afe93c76e..9af7933b22 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py @@ -1189,6 +1189,9 @@ def test_create_custom_job_empty_call(): with mock.patch.object( type(client.transport.create_custom_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1214,6 +1217,9 @@ def test_create_custom_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_custom_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_custom_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1222,6 +1228,43 @@ def test_create_custom_job_non_empty_request_with_auto_populated_field(): ) +def test_create_custom_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_custom_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_custom_job + ] = mock_rpc + request = {} + client.create_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_custom_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1249,6 +1292,52 @@ async def test_create_custom_job_empty_call_async(): assert args[0] == job_service.CreateCustomJobRequest() +@pytest.mark.asyncio +async def test_create_custom_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_custom_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_custom_job + ] = mock_object + + request = {} + await client.create_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_custom_job_async( transport: str = "grpc_asyncio", request_type=job_service.CreateCustomJobRequest @@ -1505,6 +1594,9 @@ def test_get_custom_job_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1528,6 +1620,9 @@ def test_get_custom_job_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_custom_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1536,6 +1631,41 @@ def test_get_custom_job_non_empty_request_with_auto_populated_field(): ) +def test_get_custom_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_custom_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_custom_job] = mock_rpc + request = {} + client.get_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_custom_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1561,6 +1691,52 @@ async def test_get_custom_job_empty_call_async(): assert args[0] == job_service.GetCustomJobRequest() +@pytest.mark.asyncio +async def test_get_custom_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_custom_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_custom_job + ] = mock_object + + request = {} + await client.get_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_custom_job_async( transport: str = "grpc_asyncio", request_type=job_service.GetCustomJobRequest @@ -1793,6 +1969,9 @@ def test_list_custom_jobs_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_custom_jobs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1818,6 +1997,9 @@ def test_list_custom_jobs_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_custom_jobs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1828,6 +2010,43 @@ def test_list_custom_jobs_non_empty_request_with_auto_populated_field(): ) +def test_list_custom_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_custom_jobs in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_custom_jobs + ] = mock_rpc + request = {} + client.list_custom_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_custom_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_custom_jobs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1851,6 +2070,52 @@ async def test_list_custom_jobs_empty_call_async(): assert args[0] == job_service.ListCustomJobsRequest() +@pytest.mark.asyncio +async def test_list_custom_jobs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_custom_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_custom_jobs + ] = mock_object + + request = {} + await client.list_custom_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_custom_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_custom_jobs_async( transport: str = "grpc_asyncio", request_type=job_service.ListCustomJobsRequest @@ -2270,6 +2535,9 @@ def test_delete_custom_job_empty_call(): with mock.patch.object( type(client.transport.delete_custom_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2295,6 +2563,9 @@ def test_delete_custom_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_custom_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_custom_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2303,6 +2574,47 @@ def test_delete_custom_job_non_empty_request_with_auto_populated_field(): ) +def test_delete_custom_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_custom_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_custom_job + ] = mock_rpc + request = {} + client.delete_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_custom_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2326,6 +2638,56 @@ async def test_delete_custom_job_empty_call_async(): assert args[0] == job_service.DeleteCustomJobRequest() +@pytest.mark.asyncio +async def test_delete_custom_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_custom_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_custom_job + ] = mock_object + + request = {} + await client.delete_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_custom_job_async( transport: str = "grpc_asyncio", request_type=job_service.DeleteCustomJobRequest @@ -2562,6 +2924,9 @@ def test_cancel_custom_job_empty_call(): with mock.patch.object( type(client.transport.cancel_custom_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.cancel_custom_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2587,6 +2952,9 @@ def test_cancel_custom_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.cancel_custom_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.cancel_custom_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2595,6 +2963,43 @@ def test_cancel_custom_job_non_empty_request_with_auto_populated_field(): ) +def test_cancel_custom_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.cancel_custom_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_custom_job + ] = mock_rpc + request = {} + client.cancel_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_custom_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2616,6 +3021,52 @@ async def test_cancel_custom_job_empty_call_async(): assert args[0] == job_service.CancelCustomJobRequest() +@pytest.mark.asyncio +async def test_cancel_custom_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.cancel_custom_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.cancel_custom_job + ] = mock_object + + request = {} + await client.cancel_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.cancel_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_custom_job_async( transport: str = "grpc_asyncio", request_type=job_service.CancelCustomJobRequest @@ -2865,6 +3316,9 @@ def test_create_data_labeling_job_empty_call(): with mock.patch.object( type(client.transport.create_data_labeling_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2890,6 +3344,9 @@ def test_create_data_labeling_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_data_labeling_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_data_labeling_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2898,6 +3355,46 @@ def test_create_data_labeling_job_non_empty_request_with_auto_populated_field(): ) +def test_create_data_labeling_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_data_labeling_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_data_labeling_job + ] = mock_rpc + request = {} + client.create_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_data_labeling_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2931,6 +3428,52 @@ async def test_create_data_labeling_job_empty_call_async(): assert args[0] == job_service.CreateDataLabelingJobRequest() +@pytest.mark.asyncio +async def test_create_data_labeling_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_data_labeling_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_data_labeling_job + ] = mock_object + + request = {} + await client.create_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_data_labeling_job_async( transport: str = "grpc_asyncio", @@ -3216,6 +3759,9 @@ def test_get_data_labeling_job_empty_call(): with mock.patch.object( type(client.transport.get_data_labeling_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3241,6 +3787,9 @@ def test_get_data_labeling_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_data_labeling_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_data_labeling_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3249,6 +3798,46 @@ def test_get_data_labeling_job_non_empty_request_with_auto_populated_field(): ) +def test_get_data_labeling_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_data_labeling_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.get_data_labeling_job + ] = mock_rpc + request = {} + client.get_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_data_labeling_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3283,11 +3872,57 @@ async def test_get_data_labeling_job_empty_call_async(): @pytest.mark.asyncio -async def test_get_data_labeling_job_async( - transport: str = "grpc_asyncio", request_type=job_service.GetDataLabelingJobRequest +async def test_get_data_labeling_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", ): - client = JobServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_data_labeling_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_data_labeling_job + ] = mock_object + + request = {} + await client.get_data_labeling_job(request) + + # Establish that the underlying gRPC 
stub method was called. + assert mock_object.call_count == 1 + + await client.get_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_data_labeling_job_async( + transport: str = "grpc_asyncio", request_type=job_service.GetDataLabelingJobRequest +): + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3540,6 +4175,9 @@ def test_list_data_labeling_jobs_empty_call(): with mock.patch.object( type(client.transport.list_data_labeling_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_data_labeling_jobs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3568,6 +4206,9 @@ def test_list_data_labeling_jobs_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_data_labeling_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_data_labeling_jobs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3579,6 +4220,46 @@ def test_list_data_labeling_jobs_non_empty_request_with_auto_populated_field(): ) +def test_list_data_labeling_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_data_labeling_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_data_labeling_jobs + ] = mock_rpc + request = {} + client.list_data_labeling_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_data_labeling_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_data_labeling_jobs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3604,6 +4285,52 @@ async def test_list_data_labeling_jobs_empty_call_async(): assert args[0] == job_service.ListDataLabelingJobsRequest() +@pytest.mark.asyncio +async def test_list_data_labeling_jobs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_data_labeling_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_data_labeling_jobs + ] = mock_object + + request = {} + await client.list_data_labeling_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_data_labeling_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_data_labeling_jobs_async( transport: str = "grpc_asyncio", @@ -4042,6 +4769,9 @@ def test_delete_data_labeling_job_empty_call(): with mock.patch.object( type(client.transport.delete_data_labeling_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4067,6 +4797,9 @@ def test_delete_data_labeling_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_data_labeling_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_data_labeling_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4075,6 +4808,50 @@ def test_delete_data_labeling_job_non_empty_request_with_auto_populated_field(): ) +def test_delete_data_labeling_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_data_labeling_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a 
string. + ) + client._transport._wrapped_methods[ + client._transport.delete_data_labeling_job + ] = mock_rpc + request = {} + client.delete_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_data_labeling_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4098,6 +4875,56 @@ async def test_delete_data_labeling_job_empty_call_async(): assert args[0] == job_service.DeleteDataLabelingJobRequest() +@pytest.mark.asyncio +async def test_delete_data_labeling_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_data_labeling_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_data_labeling_job + ] = mock_object + + request = {} + await client.delete_data_labeling_job(request) + + # Establish that 
the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_data_labeling_job_async( transport: str = "grpc_asyncio", @@ -4335,6 +5162,9 @@ def test_cancel_data_labeling_job_empty_call(): with mock.patch.object( type(client.transport.cancel_data_labeling_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.cancel_data_labeling_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4360,6 +5190,9 @@ def test_cancel_data_labeling_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.cancel_data_labeling_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.cancel_data_labeling_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4368,6 +5201,46 @@ def test_cancel_data_labeling_job_non_empty_request_with_auto_populated_field(): ) +def test_cancel_data_labeling_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.cancel_data_labeling_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_data_labeling_job + ] = mock_rpc + request = {} + client.cancel_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_data_labeling_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4389,6 +5262,52 @@ async def test_cancel_data_labeling_job_empty_call_async(): assert args[0] == job_service.CancelDataLabelingJobRequest() +@pytest.mark.asyncio +async def test_cancel_data_labeling_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.cancel_data_labeling_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.cancel_data_labeling_job + ] = mock_object + + request = {} + await client.cancel_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.cancel_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_data_labeling_job_async( transport: str = "grpc_asyncio", @@ -4633,6 +5552,9 @@ def test_create_hyperparameter_tuning_job_empty_call(): with mock.patch.object( type(client.transport.create_hyperparameter_tuning_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4658,6 +5580,9 @@ def test_create_hyperparameter_tuning_job_non_empty_request_with_auto_populated_ with mock.patch.object( type(client.transport.create_hyperparameter_tuning_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_hyperparameter_tuning_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4666,6 +5591,46 @@ def test_create_hyperparameter_tuning_job_non_empty_request_with_auto_populated_ ) +def test_create_hyperparameter_tuning_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_hyperparameter_tuning_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_hyperparameter_tuning_job + ] = mock_rpc + request = {} + client.create_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4696,6 +5661,52 @@ async def test_create_hyperparameter_tuning_job_empty_call_async(): assert args[0] == job_service.CreateHyperparameterTuningJobRequest() +@pytest.mark.asyncio +async def test_create_hyperparameter_tuning_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_hyperparameter_tuning_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_hyperparameter_tuning_job + ] = mock_object + + request = {} + await client.create_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_async( transport: str = "grpc_asyncio", @@ -4981,6 +5992,9 @@ def test_get_hyperparameter_tuning_job_empty_call(): with mock.patch.object( type(client.transport.get_hyperparameter_tuning_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5006,6 +6020,9 @@ def test_get_hyperparameter_tuning_job_non_empty_request_with_auto_populated_fie with mock.patch.object( type(client.transport.get_hyperparameter_tuning_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_hyperparameter_tuning_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5014,6 +6031,46 @@ def test_get_hyperparameter_tuning_job_non_empty_request_with_auto_populated_fie ) +def test_get_hyperparameter_tuning_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_hyperparameter_tuning_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_hyperparameter_tuning_job + ] = mock_rpc + request = {} + client.get_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5044,6 +6101,52 @@ async def test_get_hyperparameter_tuning_job_empty_call_async(): assert args[0] == job_service.GetHyperparameterTuningJobRequest() +@pytest.mark.asyncio +async def test_get_hyperparameter_tuning_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_hyperparameter_tuning_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_hyperparameter_tuning_job + ] = mock_object + + request = {} + await client.get_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_async( transport: str = "grpc_asyncio", @@ -5297,6 +6400,9 @@ def test_list_hyperparameter_tuning_jobs_empty_call(): with mock.patch.object( type(client.transport.list_hyperparameter_tuning_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_hyperparameter_tuning_jobs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5324,6 +6430,9 @@ def test_list_hyperparameter_tuning_jobs_non_empty_request_with_auto_populated_f with mock.patch.object( type(client.transport.list_hyperparameter_tuning_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_hyperparameter_tuning_jobs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5334,6 +6443,46 @@ def test_list_hyperparameter_tuning_jobs_non_empty_request_with_auto_populated_f ) +def test_list_hyperparameter_tuning_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_hyperparameter_tuning_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_hyperparameter_tuning_jobs + ] = mock_rpc + request = {} + client.list_hyperparameter_tuning_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_hyperparameter_tuning_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5360,17 +6509,63 @@ async def test_list_hyperparameter_tuning_jobs_empty_call_async(): @pytest.mark.asyncio -async def test_list_hyperparameter_tuning_jobs_async( +async def test_list_hyperparameter_tuning_jobs_async_use_cached_wrapped_rpc( transport: str = "grpc_asyncio", - request_type=job_service.ListHyperparameterTuningJobsRequest, ): - client = JobServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. 
+ # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_hyperparameter_tuning_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_hyperparameter_tuning_jobs + ] = mock_object + + request = {} + await client.list_hyperparameter_tuning_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.list_hyperparameter_tuning_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_hyperparameter_tuning_jobs_async( + transport: str = "grpc_asyncio", + request_type=job_service.ListHyperparameterTuningJobsRequest, +): + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. @@ -5803,6 +6998,9 @@ def test_delete_hyperparameter_tuning_job_empty_call(): with mock.patch.object( type(client.transport.delete_hyperparameter_tuning_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5828,6 +7026,9 @@ def test_delete_hyperparameter_tuning_job_non_empty_request_with_auto_populated_ with mock.patch.object( type(client.transport.delete_hyperparameter_tuning_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_hyperparameter_tuning_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5836,6 +7037,50 @@ def test_delete_hyperparameter_tuning_job_non_empty_request_with_auto_populated_ ) +def test_delete_hyperparameter_tuning_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_hyperparameter_tuning_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_hyperparameter_tuning_job + ] = mock_rpc + request = {} + client.delete_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5859,6 +7104,56 @@ async def test_delete_hyperparameter_tuning_job_empty_call_async(): assert args[0] == job_service.DeleteHyperparameterTuningJobRequest() +@pytest.mark.asyncio +async def test_delete_hyperparameter_tuning_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_hyperparameter_tuning_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_hyperparameter_tuning_job + ] = mock_object + + request = {} + await client.delete_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_async( transport: str = "grpc_asyncio", @@ -6096,6 +7391,9 @@ def test_cancel_hyperparameter_tuning_job_empty_call(): with mock.patch.object( type(client.transport.cancel_hyperparameter_tuning_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.cancel_hyperparameter_tuning_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6121,6 +7419,9 @@ def test_cancel_hyperparameter_tuning_job_non_empty_request_with_auto_populated_ with mock.patch.object( type(client.transport.cancel_hyperparameter_tuning_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.cancel_hyperparameter_tuning_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6129,6 +7430,46 @@ def test_cancel_hyperparameter_tuning_job_non_empty_request_with_auto_populated_ ) +def test_cancel_hyperparameter_tuning_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.cancel_hyperparameter_tuning_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_hyperparameter_tuning_job + ] = mock_rpc + request = {} + client.cancel_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6150,6 +7491,52 @@ async def test_cancel_hyperparameter_tuning_job_empty_call_async(): assert args[0] == job_service.CancelHyperparameterTuningJobRequest() +@pytest.mark.asyncio +async def test_cancel_hyperparameter_tuning_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.cancel_hyperparameter_tuning_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.cancel_hyperparameter_tuning_job + ] = mock_object + + request = {} + await client.cancel_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.cancel_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_async( transport: str = "grpc_asyncio", @@ -6386,6 +7773,9 @@ def test_create_nas_job_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_nas_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_nas_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6409,6 +7799,9 @@ def test_create_nas_job_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_nas_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_nas_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6417,6 +7810,41 @@ def test_create_nas_job_non_empty_request_with_auto_populated_field(): ) +def test_create_nas_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_nas_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_nas_job] = mock_rpc + request = {} + client.create_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_nas_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6443,6 +7871,52 @@ async def test_create_nas_job_empty_call_async(): assert args[0] == job_service.CreateNasJobRequest() +@pytest.mark.asyncio +async def test_create_nas_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_nas_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_nas_job + ] = mock_object + + request = {} + await client.create_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_nas_job_async( transport: str = "grpc_asyncio", request_type=job_service.CreateNasJobRequest @@ -6689,6 +8163,9 @@ def test_get_nas_job_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_nas_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_nas_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6712,6 +8189,9 @@ def test_get_nas_job_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_nas_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_nas_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6720,6 +8200,41 @@ def test_get_nas_job_non_empty_request_with_auto_populated_field(): ) +def test_get_nas_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_nas_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_nas_job] = mock_rpc + request = {} + client.get_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_nas_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6746,6 +8261,52 @@ async def test_get_nas_job_empty_call_async(): assert args[0] == job_service.GetNasJobRequest() +@pytest.mark.asyncio +async def test_get_nas_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_nas_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_nas_job + ] = mock_object + + request = {} + await client.get_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_nas_job_async( transport: str = "grpc_asyncio", request_type=job_service.GetNasJobRequest @@ -6976,6 +8537,9 @@ def test_list_nas_jobs_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_nas_jobs), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_nas_jobs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7001,6 +8565,9 @@ def test_list_nas_jobs_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_nas_jobs), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_nas_jobs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7011,6 +8578,41 @@ def test_list_nas_jobs_non_empty_request_with_auto_populated_field(): ) +def test_list_nas_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_nas_jobs in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_nas_jobs] = mock_rpc + request = {} + client.list_nas_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_nas_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_nas_jobs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7034,6 +8636,52 @@ async def test_list_nas_jobs_empty_call_async(): assert args[0] == job_service.ListNasJobsRequest() +@pytest.mark.asyncio +async def test_list_nas_jobs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_nas_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_nas_jobs + ] = mock_object + + request = {} + await client.list_nas_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_nas_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_nas_jobs_async( transport: str = "grpc_asyncio", request_type=job_service.ListNasJobsRequest @@ -7449,6 +9097,9 @@ def test_delete_nas_job_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_nas_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_nas_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7472,6 +9123,9 @@ def test_delete_nas_job_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_nas_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_nas_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7480,6 +9134,45 @@ def test_delete_nas_job_non_empty_request_with_auto_populated_field(): ) +def test_delete_nas_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_nas_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_nas_job] = mock_rpc + request = {} + client.delete_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_nas_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7502,21 +9195,71 @@ async def test_delete_nas_job_empty_call_async(): @pytest.mark.asyncio -async def test_delete_nas_job_async( - transport: str = "grpc_asyncio", request_type=job_service.DeleteNasJobRequest +async def test_delete_nas_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", ): - client = JobServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_nas_job), "__call__") as call: - # Designate an appropriate return value for the call. 
+ # Ensure method has been cached + assert ( + client._client._transport.delete_nas_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_nas_job + ] = mock_object + + request = {} + await client.delete_nas_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_nas_job_async( + transport: str = "grpc_asyncio", request_type=job_service.DeleteNasJobRequest +): + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_nas_job), "__call__") as call: + # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( operations_pb2.Operation(name="operations/spam") ) @@ -7723,6 +9466,9 @@ def test_cancel_nas_job_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.cancel_nas_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.cancel_nas_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7746,6 +9492,9 @@ def test_cancel_nas_job_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.cancel_nas_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.cancel_nas_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7754,6 +9503,41 @@ def test_cancel_nas_job_non_empty_request_with_auto_populated_field(): ) +def test_cancel_nas_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.cancel_nas_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.cancel_nas_job] = mock_rpc + request = {} + client.cancel_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_nas_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7773,6 +9557,52 @@ async def test_cancel_nas_job_empty_call_async(): assert args[0] == job_service.CancelNasJobRequest() +@pytest.mark.asyncio +async def test_cancel_nas_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.cancel_nas_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.cancel_nas_job + ] = mock_object + + request = {} + await client.cancel_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.cancel_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_nas_job_async( transport: str = "grpc_asyncio", request_type=job_service.CancelNasJobRequest @@ -7998,6 +9828,9 @@ def test_get_nas_trial_detail_empty_call(): with mock.patch.object( type(client.transport.get_nas_trial_detail), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_nas_trial_detail() call.assert_called() _, args, _ = call.mock_calls[0] @@ -8023,6 +9856,9 @@ def test_get_nas_trial_detail_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_nas_trial_detail), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_nas_trial_detail(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -8031,6 +9867,45 @@ def test_get_nas_trial_detail_non_empty_request_with_auto_populated_field(): ) +def test_get_nas_trial_detail_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_nas_trial_detail in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.get_nas_trial_detail + ] = mock_rpc + request = {} + client.get_nas_trial_detail(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_nas_trial_detail(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_nas_trial_detail_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -8057,6 +9932,52 @@ async def test_get_nas_trial_detail_empty_call_async(): assert args[0] == job_service.GetNasTrialDetailRequest() +@pytest.mark.asyncio +async def test_get_nas_trial_detail_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_nas_trial_detail + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_nas_trial_detail + ] = mock_object + + request = {} + await client.get_nas_trial_detail(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_nas_trial_detail(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_nas_trial_detail_async( transport: str = "grpc_asyncio", request_type=job_service.GetNasTrialDetailRequest @@ -8301,6 +10222,9 @@ def test_list_nas_trial_details_empty_call(): with mock.patch.object( type(client.transport.list_nas_trial_details), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_nas_trial_details() call.assert_called() _, args, _ = call.mock_calls[0] @@ -8327,6 +10251,9 @@ def test_list_nas_trial_details_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_nas_trial_details), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_nas_trial_details(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -8336,6 +10263,46 @@ def test_list_nas_trial_details_non_empty_request_with_auto_populated_field(): ) +def test_list_nas_trial_details_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_nas_trial_details + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_nas_trial_details + ] = mock_rpc + request = {} + client.list_nas_trial_details(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_nas_trial_details(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_nas_trial_details_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -8361,6 +10328,52 @@ async def test_list_nas_trial_details_empty_call_async(): assert args[0] == job_service.ListNasTrialDetailsRequest() +@pytest.mark.asyncio +async def test_list_nas_trial_details_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_nas_trial_details + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_nas_trial_details + ] = mock_object + + request = {} + await client.list_nas_trial_details(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_nas_trial_details(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_nas_trial_details_async( transport: str = "grpc_asyncio", request_type=job_service.ListNasTrialDetailsRequest @@ -8815,6 +10828,9 @@ def test_create_batch_prediction_job_empty_call(): with mock.patch.object( type(client.transport.create_batch_prediction_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -8840,6 +10856,9 @@ def test_create_batch_prediction_job_non_empty_request_with_auto_populated_field with mock.patch.object( type(client.transport.create_batch_prediction_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_batch_prediction_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -8848,6 +10867,46 @@ def test_create_batch_prediction_job_non_empty_request_with_auto_populated_field ) +def test_create_batch_prediction_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_batch_prediction_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_batch_prediction_job + ] = mock_rpc + request = {} + client.create_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_batch_prediction_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -8880,6 +10939,52 @@ async def test_create_batch_prediction_job_empty_call_async(): assert args[0] == job_service.CreateBatchPredictionJobRequest() +@pytest.mark.asyncio +async def test_create_batch_prediction_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_batch_prediction_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_batch_prediction_job + ] = mock_object + + request = {} + await client.create_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_batch_prediction_job_async( transport: str = "grpc_asyncio", @@ -9169,6 +11274,9 @@ def test_get_batch_prediction_job_empty_call(): with mock.patch.object( type(client.transport.get_batch_prediction_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -9194,6 +11302,9 @@ def test_get_batch_prediction_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_batch_prediction_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_batch_prediction_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -9202,6 +11313,46 @@ def test_get_batch_prediction_job_non_empty_request_with_auto_populated_field(): ) +def test_get_batch_prediction_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_batch_prediction_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_batch_prediction_job + ] = mock_rpc + request = {} + client.get_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_batch_prediction_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -9234,6 +11385,52 @@ async def test_get_batch_prediction_job_empty_call_async(): assert args[0] == job_service.GetBatchPredictionJobRequest() +@pytest.mark.asyncio +async def test_get_batch_prediction_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_batch_prediction_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_batch_prediction_job + ] = mock_object + + request = {} + await client.get_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_batch_prediction_job_async( transport: str = "grpc_asyncio", @@ -9491,6 +11688,9 @@ def test_list_batch_prediction_jobs_empty_call(): with mock.patch.object( type(client.transport.list_batch_prediction_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_batch_prediction_jobs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -9518,6 +11718,9 @@ def test_list_batch_prediction_jobs_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.list_batch_prediction_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_batch_prediction_jobs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -9528,6 +11731,46 @@ def test_list_batch_prediction_jobs_non_empty_request_with_auto_populated_field( ) +def test_list_batch_prediction_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_batch_prediction_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_batch_prediction_jobs + ] = mock_rpc + request = {} + client.list_batch_prediction_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_batch_prediction_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_batch_prediction_jobs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -9553,6 +11796,52 @@ async def test_list_batch_prediction_jobs_empty_call_async(): assert args[0] == job_service.ListBatchPredictionJobsRequest() +@pytest.mark.asyncio +async def test_list_batch_prediction_jobs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_batch_prediction_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_batch_prediction_jobs + ] = mock_object + + request = {} + await client.list_batch_prediction_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_batch_prediction_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_batch_prediction_jobs_async( transport: str = "grpc_asyncio", @@ -9995,6 +12284,9 @@ def test_delete_batch_prediction_job_empty_call(): with mock.patch.object( type(client.transport.delete_batch_prediction_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -10020,6 +12312,9 @@ def test_delete_batch_prediction_job_non_empty_request_with_auto_populated_field with mock.patch.object( type(client.transport.delete_batch_prediction_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_batch_prediction_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -10028,27 +12323,121 @@ def test_delete_batch_prediction_job_non_empty_request_with_auto_populated_field ) -@pytest.mark.asyncio -async def test_delete_batch_prediction_job_empty_call_async(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. 
- client = JobServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc_asyncio", - ) +def test_delete_batch_prediction_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_batch_prediction_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_batch_prediction_job + ] = mock_rpc + request = {} + client.delete_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_empty_call_async(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.delete_batch_prediction_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == job_service.DeleteBatchPredictionJobRequest() + + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_batch_prediction_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_batch_prediction_job + ] = mock_object + + request = {} + await client.delete_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_batch_prediction_job(request) - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.delete_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - response = await client.delete_batch_prediction_job() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == job_service.DeleteBatchPredictionJobRequest() + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 @pytest.mark.asyncio @@ -10288,6 +12677,9 @@ def test_cancel_batch_prediction_job_empty_call(): with mock.patch.object( type(client.transport.cancel_batch_prediction_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.cancel_batch_prediction_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -10313,6 +12705,9 @@ def test_cancel_batch_prediction_job_non_empty_request_with_auto_populated_field with mock.patch.object( type(client.transport.cancel_batch_prediction_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.cancel_batch_prediction_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -10321,6 +12716,46 @@ def test_cancel_batch_prediction_job_non_empty_request_with_auto_populated_field ) +def test_cancel_batch_prediction_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.cancel_batch_prediction_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_batch_prediction_job + ] = mock_rpc + request = {} + client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_batch_prediction_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -10342,6 +12777,52 @@ async def test_cancel_batch_prediction_job_empty_call_async(): assert args[0] == job_service.CancelBatchPredictionJobRequest() +@pytest.mark.asyncio +async def test_cancel_batch_prediction_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.cancel_batch_prediction_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.cancel_batch_prediction_job + ] = mock_object + + request = {} + await client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.cancel_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_batch_prediction_job_async( transport: str = "grpc_asyncio", @@ -10595,6 +13076,9 @@ def test_create_model_deployment_monitoring_job_empty_call(): with mock.patch.object( type(client.transport.create_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_model_deployment_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -10620,6 +13104,9 @@ def test_create_model_deployment_monitoring_job_non_empty_request_with_auto_popu with mock.patch.object( type(client.transport.create_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_model_deployment_monitoring_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -10628,6 +13115,46 @@ def test_create_model_deployment_monitoring_job_non_empty_request_with_auto_popu ) +def test_create_model_deployment_monitoring_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_model_deployment_monitoring_job + ] = mock_rpc + request = {} + client.create_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_model_deployment_monitoring_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -10660,6 +13187,52 @@ async def test_create_model_deployment_monitoring_job_empty_call_async(): assert args[0] == job_service.CreateModelDeploymentMonitoringJobRequest() +@pytest.mark.asyncio +async def test_create_model_deployment_monitoring_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_model_deployment_monitoring_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_model_deployment_monitoring_job + ] = mock_object + + request = {} + await client.create_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_model_deployment_monitoring_job_async( transport: str = "grpc_asyncio", @@ -10958,6 +13531,9 @@ def test_search_model_deployment_monitoring_stats_anomalies_empty_call(): type(client.transport.search_model_deployment_monitoring_stats_anomalies), "__call__", ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.search_model_deployment_monitoring_stats_anomalies() call.assert_called() _, args, _ = call.mock_calls[0] @@ -10990,6 +13566,9 @@ def test_search_model_deployment_monitoring_stats_anomalies_non_empty_request_wi type(client.transport.search_model_deployment_monitoring_stats_anomalies), "__call__", ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.search_model_deployment_monitoring_stats_anomalies(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -11003,6 +13582,46 @@ def test_search_model_deployment_monitoring_stats_anomalies_non_empty_request_wi ) +def test_search_model_deployment_monitoring_stats_anomalies_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.search_model_deployment_monitoring_stats_anomalies + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.search_model_deployment_monitoring_stats_anomalies + ] = mock_rpc + request = {} + client.search_model_deployment_monitoring_stats_anomalies(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_model_deployment_monitoring_stats_anomalies(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_search_model_deployment_monitoring_stats_anomalies_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -11032,6 +13651,52 @@ async def test_search_model_deployment_monitoring_stats_anomalies_empty_call_asy ) +@pytest.mark.asyncio +async def test_search_model_deployment_monitoring_stats_anomalies_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.search_model_deployment_monitoring_stats_anomalies + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.search_model_deployment_monitoring_stats_anomalies + ] = mock_object + + request = {} + await client.search_model_deployment_monitoring_stats_anomalies(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.search_model_deployment_monitoring_stats_anomalies(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_search_model_deployment_monitoring_stats_anomalies_async( transport: str = "grpc_asyncio", @@ -11539,6 +14204,9 @@ def test_get_model_deployment_monitoring_job_empty_call(): with mock.patch.object( type(client.transport.get_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_model_deployment_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -11564,6 +14232,9 @@ def test_get_model_deployment_monitoring_job_non_empty_request_with_auto_populat with mock.patch.object( type(client.transport.get_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_model_deployment_monitoring_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -11572,6 +14243,46 @@ def test_get_model_deployment_monitoring_job_non_empty_request_with_auto_populat ) +def test_get_model_deployment_monitoring_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_model_deployment_monitoring_job + ] = mock_rpc + request = {} + client.get_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_model_deployment_monitoring_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -11604,6 +14315,52 @@ async def test_get_model_deployment_monitoring_job_empty_call_async(): assert args[0] == job_service.GetModelDeploymentMonitoringJobRequest() +@pytest.mark.asyncio +async def test_get_model_deployment_monitoring_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_model_deployment_monitoring_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_model_deployment_monitoring_job + ] = mock_object + + request = {} + await client.get_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_model_deployment_monitoring_job_async( transport: str = "grpc_asyncio", @@ -11872,6 +14629,9 @@ def test_list_model_deployment_monitoring_jobs_empty_call(): with mock.patch.object( type(client.transport.list_model_deployment_monitoring_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_model_deployment_monitoring_jobs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -11899,6 +14659,9 @@ def test_list_model_deployment_monitoring_jobs_non_empty_request_with_auto_popul with mock.patch.object( type(client.transport.list_model_deployment_monitoring_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_model_deployment_monitoring_jobs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -11909,6 +14672,46 @@ def test_list_model_deployment_monitoring_jobs_non_empty_request_with_auto_popul ) +def test_list_model_deployment_monitoring_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_model_deployment_monitoring_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_model_deployment_monitoring_jobs + ] = mock_rpc + request = {} + client.list_model_deployment_monitoring_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_model_deployment_monitoring_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_model_deployment_monitoring_jobs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -11934,6 +14737,52 @@ async def test_list_model_deployment_monitoring_jobs_empty_call_async(): assert args[0] == job_service.ListModelDeploymentMonitoringJobsRequest() +@pytest.mark.asyncio +async def test_list_model_deployment_monitoring_jobs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_model_deployment_monitoring_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_model_deployment_monitoring_jobs + ] = mock_object + + request = {} + await client.list_model_deployment_monitoring_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_model_deployment_monitoring_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_model_deployment_monitoring_jobs_async( transport: str = "grpc_asyncio", @@ -12378,6 +15227,9 @@ def test_update_model_deployment_monitoring_job_empty_call(): with mock.patch.object( type(client.transport.update_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_model_deployment_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -12401,12 +15253,59 @@ def test_update_model_deployment_monitoring_job_non_empty_request_with_auto_popu with mock.patch.object( type(client.transport.update_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_model_deployment_monitoring_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == job_service.UpdateModelDeploymentMonitoringJobRequest() +def test_update_model_deployment_monitoring_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_model_deployment_monitoring_job + ] = mock_rpc + request = {} + client.update_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_model_deployment_monitoring_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -12430,6 +15329,56 @@ async def test_update_model_deployment_monitoring_job_empty_call_async(): assert args[0] == job_service.UpdateModelDeploymentMonitoringJobRequest() +@pytest.mark.asyncio +async def test_update_model_deployment_monitoring_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_model_deployment_monitoring_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_model_deployment_monitoring_job + ] = mock_object + + request = {} + await client.update_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_model_deployment_monitoring_job_async( transport: str = "grpc_asyncio", @@ -12689,6 +15638,9 @@ def test_delete_model_deployment_monitoring_job_empty_call(): with mock.patch.object( type(client.transport.delete_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_model_deployment_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -12714,6 +15666,9 @@ def test_delete_model_deployment_monitoring_job_non_empty_request_with_auto_popu with mock.patch.object( type(client.transport.delete_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_model_deployment_monitoring_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -12722,6 +15677,50 @@ def test_delete_model_deployment_monitoring_job_non_empty_request_with_auto_popu ) +def test_delete_model_deployment_monitoring_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_model_deployment_monitoring_job + ] = mock_rpc + request = {} + client.delete_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_model_deployment_monitoring_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -12745,6 +15744,56 @@ async def test_delete_model_deployment_monitoring_job_empty_call_async(): assert args[0] == job_service.DeleteModelDeploymentMonitoringJobRequest() +@pytest.mark.asyncio +async def test_delete_model_deployment_monitoring_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_model_deployment_monitoring_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_model_deployment_monitoring_job + ] = mock_object + + request = {} + await client.delete_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_model_deployment_monitoring_job_async( transport: str = "grpc_asyncio", @@ -12982,6 +16031,9 @@ def test_pause_model_deployment_monitoring_job_empty_call(): with mock.patch.object( type(client.transport.pause_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.pause_model_deployment_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -13007,6 +16059,9 @@ def test_pause_model_deployment_monitoring_job_non_empty_request_with_auto_popul with mock.patch.object( type(client.transport.pause_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.pause_model_deployment_monitoring_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -13015,6 +16070,46 @@ def test_pause_model_deployment_monitoring_job_non_empty_request_with_auto_popul ) +def test_pause_model_deployment_monitoring_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.pause_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.pause_model_deployment_monitoring_job + ] = mock_rpc + request = {} + client.pause_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.pause_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_pause_model_deployment_monitoring_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -13036,6 +16131,52 @@ async def test_pause_model_deployment_monitoring_job_empty_call_async(): assert args[0] == job_service.PauseModelDeploymentMonitoringJobRequest() +@pytest.mark.asyncio +async def test_pause_model_deployment_monitoring_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.pause_model_deployment_monitoring_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.pause_model_deployment_monitoring_job + ] = mock_object + + request = {} + await client.pause_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.pause_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_pause_model_deployment_monitoring_job_async( transport: str = "grpc_asyncio", @@ -13267,6 +16408,9 @@ def test_resume_model_deployment_monitoring_job_empty_call(): with mock.patch.object( type(client.transport.resume_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.resume_model_deployment_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -13292,6 +16436,9 @@ def test_resume_model_deployment_monitoring_job_non_empty_request_with_auto_popu with mock.patch.object( type(client.transport.resume_model_deployment_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.resume_model_deployment_monitoring_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -13300,6 +16447,46 @@ def test_resume_model_deployment_monitoring_job_non_empty_request_with_auto_popu ) +def test_resume_model_deployment_monitoring_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.resume_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.resume_model_deployment_monitoring_job + ] = mock_rpc + request = {} + client.resume_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.resume_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_resume_model_deployment_monitoring_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -13321,6 +16508,52 @@ async def test_resume_model_deployment_monitoring_job_empty_call_async(): assert args[0] == job_service.ResumeModelDeploymentMonitoringJobRequest() +@pytest.mark.asyncio +async def test_resume_model_deployment_monitoring_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = JobServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.resume_model_deployment_monitoring_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.resume_model_deployment_monitoring_job + ] = mock_object + + request = {} + await client.resume_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.resume_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_resume_model_deployment_monitoring_job_async( transport: str = "grpc_asyncio", @@ -13695,6 +16928,44 @@ def get_message_fields(field): assert response.state == job_state.JobState.JOB_STATE_QUEUED +def test_create_custom_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_custom_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_custom_job + ] = mock_rpc + + request = {} + client.create_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_custom_job_rest_required_fields( request_type=job_service.CreateCustomJobRequest, ): @@ -13977,6 +17248,42 @@ def test_get_custom_job_rest(request_type): assert response.state == job_state.JobState.JOB_STATE_QUEUED +def test_get_custom_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_custom_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_custom_job] = mock_rpc + + request = {} + client.get_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_custom_job_rest_required_fields( request_type=job_service.GetCustomJobRequest, ): @@ -14244,6 +17551,44 @@ def test_list_custom_jobs_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_custom_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_custom_jobs in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_custom_jobs + ] = mock_rpc + + request = {} + client.list_custom_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_custom_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_custom_jobs_rest_required_fields( request_type=job_service.ListCustomJobsRequest, ): @@ -14582,6 +17927,48 @@ def test_delete_custom_job_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_custom_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_custom_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_custom_job + ] = mock_rpc + + request = {} + client.delete_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_custom_job_rest_required_fields( request_type=job_service.DeleteCustomJobRequest, ): @@ -14843,6 +18230,44 @@ def test_cancel_custom_job_rest(request_type): assert response is None +def test_cancel_custom_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.cancel_custom_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_custom_job + ] = mock_rpc + + request = {} + client.cancel_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_custom_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_cancel_custom_job_rest_required_fields( request_type=job_service.CancelCustomJobRequest, ): @@ -15234,6 +18659,47 @@ def get_message_fields(field): assert response.specialist_pools == ["specialist_pools_value"] +def test_create_data_labeling_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_data_labeling_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_data_labeling_job + ] = mock_rpc + + request = {} + client.create_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_data_labeling_job_rest_required_fields( request_type=job_service.CreateDataLabelingJobRequest, ): @@ -15517,17 +18983,58 @@ def test_get_data_labeling_job_rest(request_type): req.return_value = response_value response = client.get_data_labeling_job(request) - # Establish that the response is the type that we expect. - assert isinstance(response, data_labeling_job.DataLabelingJob) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.datasets == ["datasets_value"] - assert response.labeler_count == 1375 - assert response.instruction_uri == "instruction_uri_value" - assert response.inputs_schema_uri == "inputs_schema_uri_value" - assert response.state == job_state.JobState.JOB_STATE_QUEUED - assert response.labeling_progress == 1810 - assert response.specialist_pools == ["specialist_pools_value"] + # Establish that the response is the type that we expect. 
+ assert isinstance(response, data_labeling_job.DataLabelingJob) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.datasets == ["datasets_value"] + assert response.labeler_count == 1375 + assert response.instruction_uri == "instruction_uri_value" + assert response.inputs_schema_uri == "inputs_schema_uri_value" + assert response.state == job_state.JobState.JOB_STATE_QUEUED + assert response.labeling_progress == 1810 + assert response.specialist_pools == ["specialist_pools_value"] + + +def test_get_data_labeling_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_data_labeling_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_data_labeling_job + ] = mock_rpc + + request = {} + client.get_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 def test_get_data_labeling_job_rest_required_fields( @@ -15801,6 +19308,47 @@ def test_list_data_labeling_jobs_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_data_labeling_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_data_labeling_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_data_labeling_jobs + ] = mock_rpc + + request = {} + client.list_data_labeling_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_data_labeling_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_data_labeling_jobs_rest_required_fields( request_type=job_service.ListDataLabelingJobsRequest, ): @@ -16143,6 +19691,51 @@ def test_delete_data_labeling_job_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_data_labeling_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_data_labeling_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_data_labeling_job + ] = mock_rpc + + request = {} + client.delete_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_data_labeling_job_rest_required_fields( request_type=job_service.DeleteDataLabelingJobRequest, ): @@ -16408,6 +20001,47 @@ def test_cancel_data_labeling_job_rest(request_type): assert response is None +def test_cancel_data_labeling_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.cancel_data_labeling_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_data_labeling_job + ] = mock_rpc + + request = {} + client.cancel_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_data_labeling_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_cancel_data_labeling_job_rest_required_fields( request_type=job_service.CancelDataLabelingJobRequest, ): @@ -16950,6 +20584,47 @@ def get_message_fields(field): assert response.state == job_state.JobState.JOB_STATE_QUEUED +def test_create_hyperparameter_tuning_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_hyperparameter_tuning_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_hyperparameter_tuning_job + ] = mock_rpc + + request = {} + client.create_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_hyperparameter_tuning_job_rest_required_fields( request_type=job_service.CreateHyperparameterTuningJobRequest, ): @@ -17255,6 +20930,47 @@ def test_get_hyperparameter_tuning_job_rest(request_type): assert response.state == job_state.JobState.JOB_STATE_QUEUED +def test_get_hyperparameter_tuning_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_hyperparameter_tuning_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_hyperparameter_tuning_job + ] = mock_rpc + + request = {} + client.get_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_hyperparameter_tuning_job_rest_required_fields( request_type=job_service.GetHyperparameterTuningJobRequest, ): @@ -17534,6 +21250,47 @@ def test_list_hyperparameter_tuning_jobs_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_hyperparameter_tuning_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_hyperparameter_tuning_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_hyperparameter_tuning_jobs + ] = mock_rpc + + request = {} + client.list_hyperparameter_tuning_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_hyperparameter_tuning_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_hyperparameter_tuning_jobs_rest_required_fields( request_type=job_service.ListHyperparameterTuningJobsRequest, ): @@ -17887,6 +21644,51 @@ def test_delete_hyperparameter_tuning_job_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_hyperparameter_tuning_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_hyperparameter_tuning_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_hyperparameter_tuning_job + ] = mock_rpc + + request = {} + client.delete_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_hyperparameter_tuning_job_rest_required_fields( request_type=job_service.DeleteHyperparameterTuningJobRequest, ): @@ -18155,6 +21957,47 @@ def test_cancel_hyperparameter_tuning_job_rest(request_type): assert response is None +def test_cancel_hyperparameter_tuning_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.cancel_hyperparameter_tuning_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_hyperparameter_tuning_job + ] = mock_rpc + + request = {} + client.cancel_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_hyperparameter_tuning_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_cancel_hyperparameter_tuning_job_rest_required_fields( request_type=job_service.CancelHyperparameterTuningJobRequest, ): @@ -18614,6 +22457,42 @@ def get_message_fields(field): assert response.enable_restricted_image_training is True +def test_create_nas_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_nas_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_nas_job] = mock_rpc + + request = {} + client.create_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_nas_job_rest_required_fields( request_type=job_service.CreateNasJobRequest, ): @@ -18896,6 +22775,42 @@ def test_get_nas_job_rest(request_type): assert response.enable_restricted_image_training is True +def test_get_nas_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_nas_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_nas_job] = mock_rpc + + request = {} + client.get_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_nas_job_rest_required_fields(request_type=job_service.GetNasJobRequest): transport_class = transports.JobServiceRestTransport @@ -19157,6 +23072,42 @@ def test_list_nas_jobs_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_nas_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_nas_jobs in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_nas_jobs] = mock_rpc + + request = {} + client.list_nas_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_nas_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_nas_jobs_rest_required_fields( request_type=job_service.ListNasJobsRequest, ): @@ -19491,6 +23442,46 @@ def test_delete_nas_job_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_nas_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_nas_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_nas_job] = mock_rpc + + request = {} + client.delete_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_nas_job_rest_required_fields( request_type=job_service.DeleteNasJobRequest, ): @@ -19750,6 +23741,42 @@ def test_cancel_nas_job_rest(request_type): assert response is None +def test_cancel_nas_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.cancel_nas_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.cancel_nas_job] = mock_rpc + + request = {} + client.cancel_nas_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_nas_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_cancel_nas_job_rest_required_fields( request_type=job_service.CancelNasJobRequest, ): @@ -20009,6 +24036,46 @@ def test_get_nas_trial_detail_rest(request_type): assert response.parameters == "parameters_value" +def test_get_nas_trial_detail_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_nas_trial_detail in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_nas_trial_detail + ] = mock_rpc + + request = {} + client.get_nas_trial_detail(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_nas_trial_detail(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_nas_trial_detail_rest_required_fields( request_type=job_service.GetNasTrialDetailRequest, ): @@ -20280,6 +24347,47 @@ def test_list_nas_trial_details_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_nas_trial_details_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_nas_trial_details + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_nas_trial_details + ] = mock_rpc + + request = {} + client.list_nas_trial_details(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_nas_trial_details(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_nas_trial_details_rest_required_fields( request_type=job_service.ListNasTrialDetailsRequest, ): @@ -20914,6 +25022,47 @@ def get_message_fields(field): assert response.disable_container_logging is True +def test_create_batch_prediction_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_batch_prediction_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_batch_prediction_job + ] = mock_rpc + + request = {} + client.create_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_batch_prediction_job_rest_required_fields( request_type=job_service.CreateBatchPredictionJobRequest, ): @@ -21212,6 +25361,47 @@ def test_get_batch_prediction_job_rest(request_type): assert response.disable_container_logging is True +def test_get_batch_prediction_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_batch_prediction_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_batch_prediction_job + ] = mock_rpc + + request = {} + client.get_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_batch_prediction_job_rest_required_fields( request_type=job_service.GetBatchPredictionJobRequest, ): @@ -21483,6 +25673,47 @@ def test_list_batch_prediction_jobs_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_batch_prediction_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_batch_prediction_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_batch_prediction_jobs + ] = mock_rpc + + request = {} + client.list_batch_prediction_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_batch_prediction_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_batch_prediction_jobs_rest_required_fields( request_type=job_service.ListBatchPredictionJobsRequest, ): @@ -21825,6 +26056,51 @@ def test_delete_batch_prediction_job_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_batch_prediction_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_batch_prediction_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_batch_prediction_job + ] = mock_rpc + + request = {} + client.delete_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_batch_prediction_job_rest_required_fields( request_type=job_service.DeleteBatchPredictionJobRequest, ): @@ -22090,6 +26366,47 @@ def test_cancel_batch_prediction_job_rest(request_type): assert response is None +def test_cancel_batch_prediction_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.cancel_batch_prediction_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_batch_prediction_job + ] = mock_rpc + + request = {} + client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_batch_prediction_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_cancel_batch_prediction_job_rest_required_fields( request_type=job_service.CancelBatchPredictionJobRequest, ): @@ -22544,6 +26861,47 @@ def get_message_fields(field): assert response.enable_monitoring_pipeline_logs is True +def test_create_model_deployment_monitoring_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_model_deployment_monitoring_job + ] = mock_rpc + + request = {} + client.create_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_model_deployment_monitoring_job_rest_required_fields( request_type=job_service.CreateModelDeploymentMonitoringJobRequest, ): @@ -22861,6 +27219,47 @@ def test_search_model_deployment_monitoring_stats_anomalies_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_search_model_deployment_monitoring_stats_anomalies_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.search_model_deployment_monitoring_stats_anomalies + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.search_model_deployment_monitoring_stats_anomalies + ] = mock_rpc + + request = {} + client.search_model_deployment_monitoring_stats_anomalies(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_model_deployment_monitoring_stats_anomalies(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_search_model_deployment_monitoring_stats_anomalies_rest_required_fields( request_type=job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest, ): @@ -23280,6 +27679,47 @@ def test_get_model_deployment_monitoring_job_rest(request_type): assert response.enable_monitoring_pipeline_logs is True +def test_get_model_deployment_monitoring_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_model_deployment_monitoring_job + ] = mock_rpc + + request = {} + client.get_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_model_deployment_monitoring_job_rest_required_fields( request_type=job_service.GetModelDeploymentMonitoringJobRequest, ): @@ -23568,6 +28008,47 @@ def test_list_model_deployment_monitoring_jobs_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_model_deployment_monitoring_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_model_deployment_monitoring_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_model_deployment_monitoring_jobs + ] = mock_rpc + + request = {} + client.list_model_deployment_monitoring_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_model_deployment_monitoring_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_model_deployment_monitoring_jobs_rest_required_fields( request_type=job_service.ListModelDeploymentMonitoringJobsRequest, ): @@ -24105,6 +28586,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_model_deployment_monitoring_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_model_deployment_monitoring_job + ] = mock_rpc + + request = {} + client.update_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_model_deployment_monitoring_job_rest_required_fields( request_type=job_service.UpdateModelDeploymentMonitoringJobRequest, ): @@ -24397,6 +28923,51 @@ def test_delete_model_deployment_monitoring_job_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_model_deployment_monitoring_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_model_deployment_monitoring_job + ] = mock_rpc + + request = {} + client.delete_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_model_deployment_monitoring_job_rest_required_fields( request_type=job_service.DeleteModelDeploymentMonitoringJobRequest, ): @@ -24673,6 +29244,47 @@ def test_pause_model_deployment_monitoring_job_rest(request_type): assert response is None +def test_pause_model_deployment_monitoring_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.pause_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.pause_model_deployment_monitoring_job + ] = mock_rpc + + request = {} + client.pause_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.pause_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_pause_model_deployment_monitoring_job_rest_required_fields( request_type=job_service.PauseModelDeploymentMonitoringJobRequest, ): @@ -24939,6 +29551,47 @@ def test_resume_model_deployment_monitoring_job_rest(request_type): assert response is None +def test_resume_model_deployment_monitoring_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = JobServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.resume_model_deployment_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.resume_model_deployment_monitoring_job + ] = mock_rpc + + request = {} + client.resume_model_deployment_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.resume_model_deployment_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_resume_model_deployment_monitoring_job_rest_required_fields( request_type=job_service.ResumeModelDeploymentMonitoringJobRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_llm_utility_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_llm_utility_service.py index 14e4ccf990..a7f13ce566 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_llm_utility_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_llm_utility_service.py @@ -1209,6 +1209,9 @@ def test_compute_tokens_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.compute_tokens), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.compute_tokens() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1232,6 +1235,9 @@ def test_compute_tokens_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.compute_tokens), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.compute_tokens(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1240,6 +1246,41 @@ def test_compute_tokens_non_empty_request_with_auto_populated_field(): ) +def test_compute_tokens_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = LlmUtilityServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.compute_tokens in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.compute_tokens] = mock_rpc + request = {} + client.compute_tokens(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.compute_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_compute_tokens_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1261,6 +1302,52 @@ async def test_compute_tokens_empty_call_async(): assert args[0] == llm_utility_service.ComputeTokensRequest() +@pytest.mark.asyncio +async def test_compute_tokens_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = LlmUtilityServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.compute_tokens + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.compute_tokens + ] = mock_object + + request = {} + await client.compute_tokens(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.compute_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_compute_tokens_async( transport: str = "grpc_asyncio", @@ -1488,6 +1575,42 @@ def test_compute_tokens_rest(request_type): assert isinstance(response, llm_utility_service.ComputeTokensResponse) +def test_compute_tokens_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = LlmUtilityServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.compute_tokens in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.compute_tokens] = mock_rpc + + request = {} + client.compute_tokens(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.compute_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_compute_tokens_rest_required_fields( request_type=llm_utility_service.ComputeTokensRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_match_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_match_service.py index 422b35e89b..6263538415 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_match_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_match_service.py @@ -1146,6 +1146,9 @@ def test_find_neighbors_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.find_neighbors), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.find_neighbors() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1170,6 +1173,9 @@ def test_find_neighbors_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.find_neighbors), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.find_neighbors(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1179,6 +1185,41 @@ def test_find_neighbors_non_empty_request_with_auto_populated_field(): ) +def test_find_neighbors_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MatchServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.find_neighbors in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.find_neighbors] = mock_rpc + request = {} + client.find_neighbors(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.find_neighbors(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_find_neighbors_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1200,6 +1241,52 @@ async def test_find_neighbors_empty_call_async(): assert args[0] == match_service.FindNeighborsRequest() +@pytest.mark.asyncio +async def test_find_neighbors_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MatchServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.find_neighbors + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.find_neighbors + ] = mock_object + + request = {} + await client.find_neighbors(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.find_neighbors(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_find_neighbors_async( transport: str = "grpc_asyncio", request_type=match_service.FindNeighborsRequest @@ -1344,6 +1431,9 @@ def test_read_index_datapoints_empty_call(): with mock.patch.object( type(client.transport.read_index_datapoints), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.read_index_datapoints() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1370,6 +1460,9 @@ def test_read_index_datapoints_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.read_index_datapoints), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.read_index_datapoints(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1379,6 +1472,46 @@ def test_read_index_datapoints_non_empty_request_with_auto_populated_field(): ) +def test_read_index_datapoints_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MatchServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_index_datapoints + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_index_datapoints + ] = mock_rpc + request = {} + client.read_index_datapoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_index_datapoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_read_index_datapoints_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1402,6 +1535,52 @@ async def test_read_index_datapoints_empty_call_async(): assert args[0] == match_service.ReadIndexDatapointsRequest() +@pytest.mark.asyncio +async def test_read_index_datapoints_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MatchServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.read_index_datapoints + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.read_index_datapoints + ] = mock_object + + request = {} + await client.read_index_datapoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.read_index_datapoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_read_index_datapoints_async( transport: str = "grpc_asyncio", @@ -1545,6 +1724,42 @@ def test_find_neighbors_rest(request_type): assert isinstance(response, match_service.FindNeighborsResponse) +def test_find_neighbors_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MatchServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.find_neighbors in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.find_neighbors] = mock_rpc + + request = {} + client.find_neighbors(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.find_neighbors(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_find_neighbors_rest_required_fields( request_type=match_service.FindNeighborsRequest, ): @@ -1757,6 +1972,47 @@ def test_read_index_datapoints_rest(request_type): assert isinstance(response, match_service.ReadIndexDatapointsResponse) +def test_read_index_datapoints_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MatchServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_index_datapoints + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_index_datapoints + ] = mock_rpc + + request = {} + client.read_index_datapoints(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_index_datapoints(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_read_index_datapoints_rest_required_fields( request_type=match_service.ReadIndexDatapointsRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py index d15838638e..c65c2db2a5 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py @@ -1221,6 +1221,9 @@ def test_create_metadata_store_empty_call(): with mock.patch.object( type(client.transport.create_metadata_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_metadata_store() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1247,6 +1250,9 @@ def test_create_metadata_store_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_metadata_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_metadata_store(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1256,6 +1262,50 @@ def test_create_metadata_store_non_empty_request_with_auto_populated_field(): ) +def test_create_metadata_store_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_metadata_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_metadata_store + ] = mock_rpc + request = {} + client.create_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_metadata_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_metadata_store_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1279,6 +1329,56 @@ async def test_create_metadata_store_empty_call_async(): assert args[0] == metadata_service.CreateMetadataStoreRequest() +@pytest.mark.asyncio +async def test_create_metadata_store_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_metadata_store + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_metadata_store + ] = mock_object + + request = {} + await client.create_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_metadata_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_metadata_store_async( transport: str = "grpc_asyncio", @@ -1541,6 +1641,9 @@ def test_get_metadata_store_empty_call(): with mock.patch.object( type(client.transport.get_metadata_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_metadata_store() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1566,6 +1669,9 @@ def test_get_metadata_store_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_metadata_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_metadata_store(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1574,6 +1680,45 @@ def test_get_metadata_store_non_empty_request_with_auto_populated_field(): ) +def test_get_metadata_store_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_metadata_store in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_metadata_store + ] = mock_rpc + request = {} + client.get_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_metadata_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_metadata_store_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1600,6 +1745,52 @@ async def test_get_metadata_store_empty_call_async(): assert args[0] == metadata_service.GetMetadataStoreRequest() +@pytest.mark.asyncio +async def test_get_metadata_store_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_metadata_store + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_metadata_store + ] = mock_object + + request = {} + await client.get_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_metadata_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_metadata_store_async( transport: str = "grpc_asyncio", @@ -1845,6 +2036,9 @@ def test_list_metadata_stores_empty_call(): with mock.patch.object( type(client.transport.list_metadata_stores), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_metadata_stores() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1871,6 +2065,9 @@ def test_list_metadata_stores_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_metadata_stores), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_metadata_stores(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1880,6 +2077,45 @@ def test_list_metadata_stores_non_empty_request_with_auto_populated_field(): ) +def test_list_metadata_stores_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_metadata_stores in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_metadata_stores + ] = mock_rpc + request = {} + client.list_metadata_stores(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_metadata_stores(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_metadata_stores_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1905,6 +2141,52 @@ async def test_list_metadata_stores_empty_call_async(): assert args[0] == metadata_service.ListMetadataStoresRequest() +@pytest.mark.asyncio +async def test_list_metadata_stores_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_metadata_stores + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_metadata_stores + ] = mock_object + + request = {} + await client.list_metadata_stores(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_metadata_stores(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_metadata_stores_async( transport: str = "grpc_asyncio", @@ -2343,6 +2625,9 @@ def test_delete_metadata_store_empty_call(): with mock.patch.object( type(client.transport.delete_metadata_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_metadata_store() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2368,6 +2653,9 @@ def test_delete_metadata_store_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_metadata_store), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_metadata_store(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2376,6 +2664,50 @@ def test_delete_metadata_store_non_empty_request_with_auto_populated_field(): ) +def test_delete_metadata_store_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_metadata_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_metadata_store + ] = mock_rpc + request = {} + client.delete_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_metadata_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_metadata_store_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2399,6 +2731,56 @@ async def test_delete_metadata_store_empty_call_async(): assert args[0] == metadata_service.DeleteMetadataStoreRequest() +@pytest.mark.asyncio +async def test_delete_metadata_store_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_metadata_store + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_metadata_store + ] = mock_object + + request = {} + await client.delete_metadata_store(request) + + # Establish that the underlying gRPC stub 
method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_metadata_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_metadata_store_async( transport: str = "grpc_asyncio", @@ -2649,6 +3031,9 @@ def test_create_artifact_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_artifact), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_artifact() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2673,6 +3058,9 @@ def test_create_artifact_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_artifact), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_artifact(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2682,6 +3070,41 @@ def test_create_artifact_non_empty_request_with_auto_populated_field(): ) +def test_create_artifact_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_artifact in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_artifact] = mock_rpc + request = {} + client.create_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_artifact_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2712,6 +3135,52 @@ async def test_create_artifact_empty_call_async(): assert args[0] == metadata_service.CreateArtifactRequest() +@pytest.mark.asyncio +async def test_create_artifact_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_artifact + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_artifact + ] = mock_object + + request = {} + await client.create_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_artifact_async( transport: str = "grpc_asyncio", request_type=metadata_service.CreateArtifactRequest @@ -2988,6 +3457,9 @@ def test_get_artifact_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_artifact), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_artifact() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3011,6 +3483,9 @@ def test_get_artifact_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_artifact), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_artifact(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3019,6 +3494,41 @@ def test_get_artifact_non_empty_request_with_auto_populated_field(): ) +def test_get_artifact_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_artifact in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_artifact] = mock_rpc + request = {} + client.get_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_artifact_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3049,6 +3559,52 @@ async def test_get_artifact_empty_call_async(): assert args[0] == metadata_service.GetArtifactRequest() +@pytest.mark.asyncio +async def test_get_artifact_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_artifact + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_artifact + ] = mock_object + + request = {} + await client.get_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_artifact_async( transport: str = "grpc_asyncio", request_type=metadata_service.GetArtifactRequest @@ -3287,6 +3843,9 @@ def test_list_artifacts_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_artifacts() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3313,6 +3872,9 @@ def test_list_artifacts_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_artifacts(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3324,18 +3886,53 @@ def test_list_artifacts_non_empty_request_with_auto_populated_field(): ) -@pytest.mark.asyncio -async def test_list_artifacts_empty_call_async(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. 
- client = MetadataServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc_asyncio", - ) +def test_list_artifacts_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: - # Designate an appropriate return value for the call. + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_artifacts in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_artifacts] = mock_rpc + request = {} + client.list_artifacts(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_artifacts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_artifacts_empty_call_async(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.list_artifacts), "__call__") as call: + # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( metadata_service.ListArtifactsResponse( next_page_token="next_page_token_value", @@ -3347,6 +3944,52 @@ async def test_list_artifacts_empty_call_async(): assert args[0] == metadata_service.ListArtifactsRequest() +@pytest.mark.asyncio +async def test_list_artifacts_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_artifacts + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_artifacts + ] = mock_object + + request = {} + await client.list_artifacts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_artifacts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_artifacts_async( transport: str = "grpc_asyncio", request_type=metadata_service.ListArtifactsRequest @@ -3779,6 +4422,9 @@ def test_update_artifact_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_artifact), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_artifact() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3800,12 +4446,50 @@ def test_update_artifact_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_artifact), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_artifact(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.UpdateArtifactRequest() +def test_update_artifact_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_artifact in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_artifact] = mock_rpc + request = {} + client.update_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_artifact_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3836,6 +4520,52 @@ async def test_update_artifact_empty_call_async(): assert args[0] == metadata_service.UpdateArtifactRequest() +@pytest.mark.asyncio +async def test_update_artifact_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_artifact + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_artifact + ] = mock_object + + request = {} + await client.update_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_artifact_async( transport: str = "grpc_asyncio", request_type=metadata_service.UpdateArtifactRequest @@ -4085,6 +4815,9 @@ def test_delete_artifact_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_artifact), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_artifact() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4109,6 +4842,9 @@ def test_delete_artifact_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_artifact), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_artifact(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4118,6 +4854,45 @@ def test_delete_artifact_non_empty_request_with_auto_populated_field(): ) +def test_delete_artifact_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_artifact in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_artifact] = mock_rpc + request = {} + client.delete_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_artifact_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4139,6 +4914,56 @@ async def test_delete_artifact_empty_call_async(): assert args[0] == metadata_service.DeleteArtifactRequest() +@pytest.mark.asyncio +async def test_delete_artifact_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_artifact + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_artifact + ] = mock_object + + request = {} + await client.delete_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_artifact_async( transport: str = "grpc_asyncio", request_type=metadata_service.DeleteArtifactRequest @@ -4361,6 +5186,9 @@ def test_purge_artifacts_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.purge_artifacts), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.purge_artifacts() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4385,6 +5213,9 @@ def test_purge_artifacts_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.purge_artifacts), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.purge_artifacts(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4394,6 +5225,45 @@ def test_purge_artifacts_non_empty_request_with_auto_populated_field(): ) +def test_purge_artifacts_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.purge_artifacts in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.purge_artifacts] = mock_rpc + request = {} + client.purge_artifacts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.purge_artifacts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_purge_artifacts_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4415,6 +5285,56 @@ async def test_purge_artifacts_empty_call_async(): assert args[0] == metadata_service.PurgeArtifactsRequest() +@pytest.mark.asyncio +async def test_purge_artifacts_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.purge_artifacts + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.purge_artifacts + ] = mock_object + + request = {} + await client.purge_artifacts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.purge_artifacts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_purge_artifacts_async( transport: str = "grpc_asyncio", request_type=metadata_service.PurgeArtifactsRequest @@ -4652,6 +5572,9 @@ def test_create_context_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_context), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_context() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4676,6 +5599,9 @@ def test_create_context_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_context), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_context(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4685,6 +5611,41 @@ def test_create_context_non_empty_request_with_auto_populated_field(): ) +def test_create_context_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_context in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_context] = mock_rpc + request = {} + client.create_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_context_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4714,6 +5675,52 @@ async def test_create_context_empty_call_async(): assert args[0] == metadata_service.CreateContextRequest() +@pytest.mark.asyncio +async def test_create_context_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_context + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_context + ] = mock_object + + request = {} + await client.create_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_context_async( transport: str = "grpc_asyncio", request_type=metadata_service.CreateContextRequest @@ -4982,6 +5989,9 @@ def test_get_context_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_context), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_context() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5005,6 +6015,9 @@ def test_get_context_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_context), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_context(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5013,6 +6026,41 @@ def test_get_context_non_empty_request_with_auto_populated_field(): ) +def test_get_context_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_context in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_context] = mock_rpc + request = {} + client.get_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_context_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5042,6 +6090,52 @@ async def test_get_context_empty_call_async(): assert args[0] == metadata_service.GetContextRequest() +@pytest.mark.asyncio +async def test_get_context_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_context + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_context + ] = mock_object + + request = {} + await client.get_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_context_async( transport: str = "grpc_asyncio", request_type=metadata_service.GetContextRequest @@ -5278,6 +6372,9 @@ def test_list_contexts_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_contexts() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5304,6 +6401,9 @@ def test_list_contexts_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_contexts(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5315,6 +6415,41 @@ def test_list_contexts_non_empty_request_with_auto_populated_field(): ) +def test_list_contexts_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_contexts in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_contexts] = mock_rpc + request = {} + client.list_contexts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_contexts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_contexts_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5339,7 +6474,53 @@ async def test_list_contexts_empty_call_async(): @pytest.mark.asyncio -async def test_list_contexts_async( +async def test_list_contexts_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_contexts + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_contexts + ] = mock_object + + request = {} + await client.list_contexts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_contexts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_contexts_async( transport: str = "grpc_asyncio", request_type=metadata_service.ListContextsRequest ): client = MetadataServiceAsyncClient( @@ -5768,6 +6949,9 @@ def test_update_context_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_context), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_context() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5789,12 +6973,50 @@ def test_update_context_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_context), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_context(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.UpdateContextRequest() +def test_update_context_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_context in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_context] = mock_rpc + request = {} + client.update_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_context_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5824,6 +7046,52 @@ async def test_update_context_empty_call_async(): assert args[0] == metadata_service.UpdateContextRequest() +@pytest.mark.asyncio +async def test_update_context_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_context + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_context + ] = mock_object + + request = {} + await client.update_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_context_async( transport: str = "grpc_asyncio", request_type=metadata_service.UpdateContextRequest @@ -6067,6 +7335,9 @@ def test_delete_context_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_context), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_context() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6091,6 +7362,9 @@ def test_delete_context_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_context), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_context(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6100,6 +7374,45 @@ def test_delete_context_non_empty_request_with_auto_populated_field(): ) +def test_delete_context_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_context in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_context] = mock_rpc + request = {} + client.delete_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_context_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6121,6 +7434,56 @@ async def test_delete_context_empty_call_async(): assert args[0] == metadata_service.DeleteContextRequest() +@pytest.mark.asyncio +async def test_delete_context_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_context + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_context + ] = mock_object + + request = {} + await client.delete_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_context_async( transport: str = "grpc_asyncio", request_type=metadata_service.DeleteContextRequest @@ -6343,6 +7706,9 @@ def test_purge_contexts_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.purge_contexts), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.purge_contexts() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6367,6 +7733,9 @@ def test_purge_contexts_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.purge_contexts), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.purge_contexts(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6376,6 +7745,45 @@ def test_purge_contexts_non_empty_request_with_auto_populated_field(): ) +def test_purge_contexts_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.purge_contexts in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.purge_contexts] = mock_rpc + request = {} + client.purge_contexts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.purge_contexts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_purge_contexts_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6397,6 +7805,56 @@ async def test_purge_contexts_empty_call_async(): assert args[0] == metadata_service.PurgeContextsRequest() +@pytest.mark.asyncio +async def test_purge_contexts_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.purge_contexts + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.purge_contexts + ] = mock_object + + request = {} + await client.purge_contexts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.purge_contexts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_purge_contexts_async( transport: str = "grpc_asyncio", request_type=metadata_service.PurgeContextsRequest @@ -6625,6 +8083,9 @@ def test_add_context_artifacts_and_executions_empty_call(): with mock.patch.object( type(client.transport.add_context_artifacts_and_executions), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.add_context_artifacts_and_executions() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6650,6 +8111,9 @@ def test_add_context_artifacts_and_executions_non_empty_request_with_auto_popula with mock.patch.object( type(client.transport.add_context_artifacts_and_executions), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.add_context_artifacts_and_executions(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6658,6 +8122,46 @@ def test_add_context_artifacts_and_executions_non_empty_request_with_auto_popula ) +def test_add_context_artifacts_and_executions_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.add_context_artifacts_and_executions + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.add_context_artifacts_and_executions + ] = mock_rpc + request = {} + client.add_context_artifacts_and_executions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.add_context_artifacts_and_executions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_add_context_artifacts_and_executions_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6681,6 +8185,52 @@ async def test_add_context_artifacts_and_executions_empty_call_async(): assert args[0] == metadata_service.AddContextArtifactsAndExecutionsRequest() +@pytest.mark.asyncio +async def test_add_context_artifacts_and_executions_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.add_context_artifacts_and_executions + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.add_context_artifacts_and_executions + ] = mock_object + + request = {} + await client.add_context_artifacts_and_executions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.add_context_artifacts_and_executions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_add_context_artifacts_and_executions_async( transport: str = "grpc_asyncio", @@ -6940,6 +8490,9 @@ def test_add_context_children_empty_call(): with mock.patch.object( type(client.transport.add_context_children), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.add_context_children() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6965,6 +8518,9 @@ def test_add_context_children_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.add_context_children), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.add_context_children(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6973,6 +8529,45 @@ def test_add_context_children_non_empty_request_with_auto_populated_field(): ) +def test_add_context_children_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.add_context_children in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.add_context_children + ] = mock_rpc + request = {} + client.add_context_children(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.add_context_children(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_add_context_children_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6996,6 +8591,52 @@ async def test_add_context_children_empty_call_async(): assert args[0] == metadata_service.AddContextChildrenRequest() +@pytest.mark.asyncio +async def test_add_context_children_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.add_context_children + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.add_context_children + ] = mock_object + + request = {} + await client.add_context_children(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.add_context_children(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_add_context_children_async( transport: str = "grpc_asyncio", @@ -7243,6 +8884,9 @@ def test_remove_context_children_empty_call(): with mock.patch.object( type(client.transport.remove_context_children), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.remove_context_children() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7268,6 +8912,9 @@ def test_remove_context_children_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.remove_context_children), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.remove_context_children(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7276,6 +8923,46 @@ def test_remove_context_children_non_empty_request_with_auto_populated_field(): ) +def test_remove_context_children_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.remove_context_children + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.remove_context_children + ] = mock_rpc + request = {} + client.remove_context_children(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.remove_context_children(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_remove_context_children_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7300,24 +8987,70 @@ async def test_remove_context_children_empty_call_async(): @pytest.mark.asyncio -async def test_remove_context_children_async( +async def test_remove_context_children_async_use_cached_wrapped_rpc( transport: str = "grpc_asyncio", - request_type=metadata_service.RemoveContextChildrenRequest, ): - client = MetadataServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.remove_context_children), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
+ # Ensure method has been cached + assert ( + client._client._transport.remove_context_children + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.remove_context_children + ] = mock_object + + request = {} + await client.remove_context_children(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.remove_context_children(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_remove_context_children_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.RemoveContextChildrenRequest, +): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.remove_context_children), "__call__" + ) as call: + # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( metadata_service.RemoveContextChildrenResponse() ) @@ -7546,6 +9279,9 @@ def test_query_context_lineage_subgraph_empty_call(): with mock.patch.object( type(client.transport.query_context_lineage_subgraph), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.query_context_lineage_subgraph() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7571,6 +9307,9 @@ def test_query_context_lineage_subgraph_non_empty_request_with_auto_populated_fi with mock.patch.object( type(client.transport.query_context_lineage_subgraph), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.query_context_lineage_subgraph(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7579,6 +9318,46 @@ def test_query_context_lineage_subgraph_non_empty_request_with_auto_populated_fi ) +def test_query_context_lineage_subgraph_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.query_context_lineage_subgraph + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.query_context_lineage_subgraph + ] = mock_rpc + request = {} + client.query_context_lineage_subgraph(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.query_context_lineage_subgraph(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_query_context_lineage_subgraph_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7602,6 +9381,52 @@ async def test_query_context_lineage_subgraph_empty_call_async(): assert args[0] == metadata_service.QueryContextLineageSubgraphRequest() +@pytest.mark.asyncio +async def test_query_context_lineage_subgraph_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.query_context_lineage_subgraph + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.query_context_lineage_subgraph + ] = mock_object + + request = {} + await client.query_context_lineage_subgraph(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.query_context_lineage_subgraph(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_query_context_lineage_subgraph_async( transport: str = "grpc_asyncio", @@ -7850,6 +9675,9 @@ def test_create_execution_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_execution), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_execution() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7874,6 +9702,9 @@ def test_create_execution_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_execution), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_execution(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7883,6 +9714,43 @@ def test_create_execution_non_empty_request_with_auto_populated_field(): ) +def test_create_execution_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_execution in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_execution + ] = mock_rpc + request = {} + client.create_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_execution_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7912,6 +9780,52 @@ async def test_create_execution_empty_call_async(): assert args[0] == metadata_service.CreateExecutionRequest() +@pytest.mark.asyncio +async def test_create_execution_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_execution + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_execution + ] = mock_object + + request = {} + await client.create_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_execution_async( transport: str = "grpc_asyncio", @@ -8185,6 +10099,9 @@ def test_get_execution_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_execution), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_execution() call.assert_called() _, args, _ = call.mock_calls[0] @@ -8208,6 +10125,9 @@ def test_get_execution_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_execution), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_execution(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -8216,6 +10136,41 @@ def test_get_execution_non_empty_request_with_auto_populated_field(): ) +def test_get_execution_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_execution in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_execution] = mock_rpc + request = {} + client.get_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_execution_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -8245,6 +10200,52 @@ async def test_get_execution_empty_call_async(): assert args[0] == metadata_service.GetExecutionRequest() +@pytest.mark.asyncio +async def test_get_execution_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_execution + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_execution + ] = mock_object + + request = {} + await client.get_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_execution_async( transport: str = "grpc_asyncio", request_type=metadata_service.GetExecutionRequest @@ -8481,6 +10482,9 @@ def test_list_executions_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_executions), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_executions() call.assert_called() _, args, _ = call.mock_calls[0] @@ -8507,6 +10511,9 @@ def test_list_executions_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_executions), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_executions(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -8518,6 +10525,41 @@ def test_list_executions_non_empty_request_with_auto_populated_field(): ) +def test_list_executions_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_executions in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_executions] = mock_rpc + request = {} + client.list_executions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_executions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_executions_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -8541,6 +10583,52 @@ async def test_list_executions_empty_call_async(): assert args[0] == metadata_service.ListExecutionsRequest() +@pytest.mark.asyncio +async def test_list_executions_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_executions + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_executions + ] = mock_object + + request = {} + await client.list_executions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_executions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_executions_async( transport: str = "grpc_asyncio", request_type=metadata_service.ListExecutionsRequest @@ -8971,6 +11059,9 @@ def test_update_execution_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_execution), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_execution() call.assert_called() _, args, _ = call.mock_calls[0] @@ -8992,12 +11083,52 @@ def test_update_execution_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_execution), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_execution(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == metadata_service.UpdateExecutionRequest() +def test_update_execution_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_execution in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_execution + ] = mock_rpc + request = {} + client.update_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_execution_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -9027,6 +11158,52 @@ async def test_update_execution_empty_call_async(): assert args[0] == metadata_service.UpdateExecutionRequest() +@pytest.mark.asyncio +async def test_update_execution_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_execution + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_execution + ] = mock_object + + request = {} + await client.update_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_execution_async( transport: str = "grpc_asyncio", @@ -9275,6 +11452,9 @@ def test_delete_execution_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_execution), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_execution() call.assert_called() _, args, _ = call.mock_calls[0] @@ -9299,6 +11479,9 @@ def test_delete_execution_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_execution), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_execution(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -9308,6 +11491,47 @@ def test_delete_execution_non_empty_request_with_auto_populated_field(): ) +def test_delete_execution_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_execution in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_execution + ] = mock_rpc + request = {} + client.delete_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_execution_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -9329,6 +11553,56 @@ async def test_delete_execution_empty_call_async(): assert args[0] == metadata_service.DeleteExecutionRequest() +@pytest.mark.asyncio +async def test_delete_execution_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_execution + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_execution + ] = mock_object + + request = {} + await client.delete_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_execution_async( transport: str = "grpc_asyncio", @@ -9552,6 +11826,9 @@ def test_purge_executions_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.purge_executions), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.purge_executions() call.assert_called() _, args, _ = call.mock_calls[0] @@ -9576,6 +11853,9 @@ def test_purge_executions_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.purge_executions), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.purge_executions(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -9585,6 +11865,47 @@ def test_purge_executions_non_empty_request_with_auto_populated_field(): ) +def test_purge_executions_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.purge_executions in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.purge_executions + ] = mock_rpc + request = {} + client.purge_executions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.purge_executions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_purge_executions_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -9606,6 +11927,56 @@ async def test_purge_executions_empty_call_async(): assert args[0] == metadata_service.PurgeExecutionsRequest() +@pytest.mark.asyncio +async def test_purge_executions_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.purge_executions + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.purge_executions + ] = mock_object + + request = {} + await client.purge_executions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.purge_executions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_purge_executions_async( transport: str = "grpc_asyncio", @@ -9833,6 +12204,9 @@ def test_add_execution_events_empty_call(): with mock.patch.object( type(client.transport.add_execution_events), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.add_execution_events() call.assert_called() _, args, _ = call.mock_calls[0] @@ -9858,6 +12232,9 @@ def test_add_execution_events_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.add_execution_events), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.add_execution_events(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -9866,6 +12243,45 @@ def test_add_execution_events_non_empty_request_with_auto_populated_field(): ) +def test_add_execution_events_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.add_execution_events in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.add_execution_events + ] = mock_rpc + request = {} + client.add_execution_events(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.add_execution_events(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_add_execution_events_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -9889,6 +12305,52 @@ async def test_add_execution_events_empty_call_async(): assert args[0] == metadata_service.AddExecutionEventsRequest() +@pytest.mark.asyncio +async def test_add_execution_events_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.add_execution_events + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.add_execution_events + ] = mock_object + + request = {} + await client.add_execution_events(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.add_execution_events(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_add_execution_events_async( transport: str = "grpc_asyncio", @@ -10136,6 +12598,9 @@ def test_query_execution_inputs_and_outputs_empty_call(): with mock.patch.object( type(client.transport.query_execution_inputs_and_outputs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.query_execution_inputs_and_outputs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -10161,6 +12626,9 @@ def test_query_execution_inputs_and_outputs_non_empty_request_with_auto_populate with mock.patch.object( type(client.transport.query_execution_inputs_and_outputs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.query_execution_inputs_and_outputs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -10169,6 +12637,46 @@ def test_query_execution_inputs_and_outputs_non_empty_request_with_auto_populate ) +def test_query_execution_inputs_and_outputs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.query_execution_inputs_and_outputs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.query_execution_inputs_and_outputs + ] = mock_rpc + request = {} + client.query_execution_inputs_and_outputs(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.query_execution_inputs_and_outputs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_query_execution_inputs_and_outputs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -10178,18 +12686,64 @@ async def test_query_execution_inputs_and_outputs_empty_call_async(): transport="grpc_asyncio", ) - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.query_execution_inputs_and_outputs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - lineage_subgraph.LineageSubgraph() - ) - response = await client.query_execution_inputs_and_outputs() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.QueryExecutionInputsAndOutputsRequest() + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.query_execution_inputs_and_outputs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + lineage_subgraph.LineageSubgraph() + ) + response = await client.query_execution_inputs_and_outputs() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == metadata_service.QueryExecutionInputsAndOutputsRequest() + + +@pytest.mark.asyncio +async def test_query_execution_inputs_and_outputs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.query_execution_inputs_and_outputs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + 
client._client._transport.query_execution_inputs_and_outputs + ] = mock_object + + request = {} + await client.query_execution_inputs_and_outputs(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.query_execution_inputs_and_outputs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 @pytest.mark.asyncio @@ -10443,6 +12997,9 @@ def test_create_metadata_schema_empty_call(): with mock.patch.object( type(client.transport.create_metadata_schema), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_metadata_schema() call.assert_called() _, args, _ = call.mock_calls[0] @@ -10469,6 +13026,9 @@ def test_create_metadata_schema_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_metadata_schema), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_metadata_schema(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -10478,6 +13038,46 @@ def test_create_metadata_schema_non_empty_request_with_auto_populated_field(): ) +def test_create_metadata_schema_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_metadata_schema + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_metadata_schema + ] = mock_rpc + request = {} + client.create_metadata_schema(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_metadata_schema(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_metadata_schema_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -10507,6 +13107,52 @@ async def test_create_metadata_schema_empty_call_async(): assert args[0] == metadata_service.CreateMetadataSchemaRequest() +@pytest.mark.asyncio +async def test_create_metadata_schema_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_metadata_schema + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_metadata_schema + ] = mock_object + + request = {} + await client.create_metadata_schema(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_metadata_schema(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_metadata_schema_async( transport: str = "grpc_asyncio", @@ -10792,6 +13438,9 @@ def test_get_metadata_schema_empty_call(): with mock.patch.object( type(client.transport.get_metadata_schema), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_metadata_schema() call.assert_called() _, args, _ = call.mock_calls[0] @@ -10817,6 +13466,9 @@ def test_get_metadata_schema_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_metadata_schema), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_metadata_schema(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -10825,6 +13477,45 @@ def test_get_metadata_schema_non_empty_request_with_auto_populated_field(): ) +def test_get_metadata_schema_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_metadata_schema in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.get_metadata_schema + ] = mock_rpc + request = {} + client.get_metadata_schema(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_metadata_schema(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_metadata_schema_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -10854,6 +13545,52 @@ async def test_get_metadata_schema_empty_call_async(): assert args[0] == metadata_service.GetMetadataSchemaRequest() +@pytest.mark.asyncio +async def test_get_metadata_schema_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_metadata_schema + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_metadata_schema + ] = mock_object + + request = {} + await client.get_metadata_schema(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_metadata_schema(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_metadata_schema_async( transport: str = "grpc_asyncio", @@ -11108,6 +13845,9 @@ def test_list_metadata_schemas_empty_call(): with mock.patch.object( type(client.transport.list_metadata_schemas), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_metadata_schemas() call.assert_called() _, args, _ = call.mock_calls[0] @@ -11135,6 +13875,9 @@ def test_list_metadata_schemas_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_metadata_schemas), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_metadata_schemas(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -11145,6 +13888,46 @@ def test_list_metadata_schemas_non_empty_request_with_auto_populated_field(): ) +def test_list_metadata_schemas_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_metadata_schemas + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_metadata_schemas + ] = mock_rpc + request = {} + client.list_metadata_schemas(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_metadata_schemas(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_metadata_schemas_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -11170,6 +13953,52 @@ async def test_list_metadata_schemas_empty_call_async(): assert args[0] == metadata_service.ListMetadataSchemasRequest() +@pytest.mark.asyncio +async def test_list_metadata_schemas_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_metadata_schemas + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_metadata_schemas + ] = mock_object + + request = {} + await client.list_metadata_schemas(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_metadata_schemas(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_metadata_schemas_async( transport: str = "grpc_asyncio", @@ -11608,6 +14437,9 @@ def test_query_artifact_lineage_subgraph_empty_call(): with mock.patch.object( type(client.transport.query_artifact_lineage_subgraph), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.query_artifact_lineage_subgraph() call.assert_called() _, args, _ = call.mock_calls[0] @@ -11634,6 +14466,9 @@ def test_query_artifact_lineage_subgraph_non_empty_request_with_auto_populated_f with mock.patch.object( type(client.transport.query_artifact_lineage_subgraph), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.query_artifact_lineage_subgraph(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -11643,6 +14478,46 @@ def test_query_artifact_lineage_subgraph_non_empty_request_with_auto_populated_f ) +def test_query_artifact_lineage_subgraph_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.query_artifact_lineage_subgraph + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.query_artifact_lineage_subgraph + ] = mock_rpc + request = {} + client.query_artifact_lineage_subgraph(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.query_artifact_lineage_subgraph(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_query_artifact_lineage_subgraph_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -11666,6 +14541,52 @@ async def test_query_artifact_lineage_subgraph_empty_call_async(): assert args[0] == metadata_service.QueryArtifactLineageSubgraphRequest() +@pytest.mark.asyncio +async def test_query_artifact_lineage_subgraph_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.query_artifact_lineage_subgraph + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.query_artifact_lineage_subgraph + ] = mock_object + + request = {} + await client.query_artifact_lineage_subgraph(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.query_artifact_lineage_subgraph(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_query_artifact_lineage_subgraph_async( transport: str = "grpc_asyncio", @@ -11968,6 +14889,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_metadata_store_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_metadata_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_metadata_store + ] = mock_rpc + + request = {} + client.create_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_metadata_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_metadata_store_rest_required_fields( request_type=metadata_service.CreateMetadataStoreRequest, ): @@ -12249,6 +15215,46 @@ def test_get_metadata_store_rest(request_type): assert response.description == "description_value" +def test_get_metadata_store_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_metadata_store in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_metadata_store + ] = mock_rpc + + request = {} + client.get_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_metadata_store(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_metadata_store_rest_required_fields( request_type=metadata_service.GetMetadataStoreRequest, ): @@ -12518,6 +15524,46 @@ def test_list_metadata_stores_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_metadata_stores_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_metadata_stores in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_metadata_stores + ] = mock_rpc + + request = {} + client.list_metadata_stores(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_metadata_stores(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_metadata_stores_rest_required_fields( request_type=metadata_service.ListMetadataStoresRequest, ): @@ -12834,22 +15880,67 @@ def test_delete_metadata_store_rest(request_type): request_init = {"name": "projects/sample1/locations/sample2/metadataStores/sample3"} request = request_type(**request_init) - # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), "request") as req: - # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.delete_metadata_store(request) + + # Establish that the response is the type that we expect. 
+ assert response.operation.name == "operations/spam" + + +def test_delete_metadata_store_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_metadata_store + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_metadata_store + ] = mock_rpc + + request = {} + client.delete_metadata_store(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 200 - json_return_value = json_format.MessageToJson(return_value) + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() - response_value._content = json_return_value.encode("UTF-8") - req.return_value = response_value - response = client.delete_metadata_store(request) + client.delete_metadata_store(request) - # Establish that the response is the type that we expect. 
- assert response.operation.name == "operations/spam" + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 def test_delete_metadata_store_rest_required_fields( @@ -13217,6 +16308,42 @@ def get_message_fields(field): assert response.description == "description_value" +def test_create_artifact_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_artifact in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_artifact] = mock_rpc + + request = {} + client.create_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_artifact_rest_required_fields( request_type=metadata_service.CreateArtifactRequest, ): @@ -13519,6 +16646,42 @@ def test_get_artifact_rest(request_type): assert response.description == "description_value" +def test_get_artifact_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_artifact in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_artifact] = mock_rpc + + request = {} + client.get_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_artifact_rest_required_fields( request_type=metadata_service.GetArtifactRequest, ): @@ -13790,6 +16953,42 @@ def test_list_artifacts_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_artifacts_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_artifacts in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_artifacts] = mock_rpc + + request = {} + client.list_artifacts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_artifacts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_artifacts_rest_required_fields( request_type=metadata_service.ListArtifactsRequest, ): @@ -14238,6 +17437,42 @@ def get_message_fields(field): assert response.description == "description_value" +def test_update_artifact_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_artifact in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_artifact] = mock_rpc + + request = {} + client.update_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_artifact_rest_required_fields( request_type=metadata_service.UpdateArtifactRequest, ): @@ -14523,6 +17758,46 @@ def test_delete_artifact_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_artifact_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_artifact in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_artifact] = mock_rpc + + request = {} + client.delete_artifact(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_artifact(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_artifact_rest_required_fields( request_type=metadata_service.DeleteArtifactRequest, ): @@ -14790,6 +18065,46 @@ def test_purge_artifacts_rest(request_type): assert response.operation.name == "operations/spam" +def test_purge_artifacts_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.purge_artifacts in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.purge_artifacts] = mock_rpc + + request = {} + client.purge_artifacts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.purge_artifacts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_purge_artifacts_rest_required_fields( request_type=metadata_service.PurgeArtifactsRequest, ): @@ -15165,6 +18480,42 @@ def get_message_fields(field): assert response.description == "description_value" +def test_create_context_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_context in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_context] = mock_rpc + + request = {} + client.create_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_context_rest_required_fields( request_type=metadata_service.CreateContextRequest, ): @@ -15463,6 +18814,42 @@ def test_get_context_rest(request_type): assert response.description == "description_value" +def test_get_context_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_context in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_context] = mock_rpc + + request = {} + client.get_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_context_rest_required_fields( request_type=metadata_service.GetContextRequest, ): @@ -15734,6 +19121,42 @@ def test_list_contexts_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_contexts_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_contexts in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_contexts] = mock_rpc + + request = {} + client.list_contexts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_contexts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_contexts_rest_required_fields( request_type=metadata_service.ListContextsRequest, ): @@ -16179,6 +19602,42 @@ def get_message_fields(field): assert response.description == "description_value" +def test_update_context_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_context in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_context] = mock_rpc + + request = {} + client.update_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_context_rest_required_fields( request_type=metadata_service.UpdateContextRequest, ): @@ -16462,6 +19921,46 @@ def test_delete_context_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_context_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_context in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_context] = mock_rpc + + request = {} + client.delete_context(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_context(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_context_rest_required_fields( request_type=metadata_service.DeleteContextRequest, ): @@ -16742,6 +20241,46 @@ def test_purge_contexts_rest(request_type): assert response.operation.name == "operations/spam" +def test_purge_contexts_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.purge_contexts in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.purge_contexts] = mock_rpc + + request = {} + client.purge_contexts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.purge_contexts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_purge_contexts_rest_required_fields( request_type=metadata_service.PurgeContextsRequest, ): @@ -17026,6 +20565,47 @@ def test_add_context_artifacts_and_executions_rest(request_type): ) +def test_add_context_artifacts_and_executions_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.add_context_artifacts_and_executions + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.add_context_artifacts_and_executions + ] = mock_rpc + + request = {} + client.add_context_artifacts_and_executions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.add_context_artifacts_and_executions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_add_context_artifacts_and_executions_rest_required_fields( request_type=metadata_service.AddContextArtifactsAndExecutionsRequest, ): @@ -17299,19 +20879,59 @@ def test_add_context_children_rest(request_type): # Designate an appropriate value for the returned response. return_value = metadata_service.AddContextChildrenResponse() - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 200 - # Convert return value to protobuf type - return_value = metadata_service.AddContextChildrenResponse.pb(return_value) - json_return_value = json_format.MessageToJson(return_value) + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = metadata_service.AddContextChildrenResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.add_context_children(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, metadata_service.AddContextChildrenResponse) + + +def test_add_context_children_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.add_context_children in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.add_context_children + ] = mock_rpc + + request = {} + client.add_context_children(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 - response_value._content = json_return_value.encode("UTF-8") - req.return_value = response_value - response = client.add_context_children(request) + client.add_context_children(request) - # Establish that the response is the type that we expect. 
- assert isinstance(response, metadata_service.AddContextChildrenResponse) + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 def test_add_context_children_rest_required_fields( @@ -17587,6 +21207,47 @@ def test_remove_context_children_rest(request_type): assert isinstance(response, metadata_service.RemoveContextChildrenResponse) +def test_remove_context_children_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.remove_context_children + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.remove_context_children + ] = mock_rpc + + request = {} + client.remove_context_children(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.remove_context_children(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_remove_context_children_rest_required_fields( request_type=metadata_service.RemoveContextChildrenRequest, ): @@ -17864,6 +21525,47 @@ def test_query_context_lineage_subgraph_rest(request_type): assert isinstance(response, lineage_subgraph.LineageSubgraph) +def test_query_context_lineage_subgraph_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.query_context_lineage_subgraph + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.query_context_lineage_subgraph + ] = mock_rpc + + request = {} + client.query_context_lineage_subgraph(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.query_context_lineage_subgraph(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_query_context_lineage_subgraph_rest_required_fields( request_type=metadata_service.QueryContextLineageSubgraphRequest, ): @@ -18232,6 +21934,44 @@ def get_message_fields(field): assert response.description == "description_value" +def test_create_execution_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_execution in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_execution + ] = mock_rpc + + request = {} + client.create_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_execution_rest_required_fields( request_type=metadata_service.CreateExecutionRequest, ): @@ -18532,6 +22272,42 @@ def test_get_execution_rest(request_type): assert response.description == "description_value" +def test_get_execution_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_execution in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_execution] = mock_rpc + + request = {} + client.get_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_execution_rest_required_fields( request_type=metadata_service.GetExecutionRequest, ): @@ -18803,6 +22579,42 @@ def test_list_executions_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_executions_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_executions in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_executions] = mock_rpc + + request = {} + client.list_executions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_executions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_executions_rest_required_fields( request_type=metadata_service.ListExecutionsRequest, ): @@ -19248,6 +23060,44 @@ def get_message_fields(field): assert response.description == "description_value" +def test_update_execution_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_execution in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_execution + ] = mock_rpc + + request = {} + client.update_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_execution_rest_required_fields( request_type=metadata_service.UpdateExecutionRequest, ): @@ -19533,6 +23383,48 @@ def test_delete_execution_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_execution_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_execution in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_execution + ] = mock_rpc + + request = {} + client.delete_execution(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_execution(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_execution_rest_required_fields( request_type=metadata_service.DeleteExecutionRequest, ): @@ -19800,6 +23692,48 @@ def test_purge_executions_rest(request_type): assert response.operation.name == "operations/spam" +def test_purge_executions_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.purge_executions in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.purge_executions + ] = mock_rpc + + request = {} + client.purge_executions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.purge_executions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_purge_executions_rest_required_fields( request_type=metadata_service.PurgeExecutionsRequest, ): @@ -20080,6 +24014,46 @@ def test_add_execution_events_rest(request_type): assert isinstance(response, metadata_service.AddExecutionEventsResponse) +def test_add_execution_events_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.add_execution_events in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.add_execution_events + ] = mock_rpc + + request = {} + client.add_execution_events(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.add_execution_events(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_add_execution_events_rest_required_fields( request_type=metadata_service.AddExecutionEventsRequest, ): @@ -20353,6 +24327,47 @@ def test_query_execution_inputs_and_outputs_rest(request_type): assert isinstance(response, lineage_subgraph.LineageSubgraph) +def test_query_execution_inputs_and_outputs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.query_execution_inputs_and_outputs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.query_execution_inputs_and_outputs + ] = mock_rpc + + request = {} + client.query_execution_inputs_and_outputs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.query_execution_inputs_and_outputs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_query_execution_inputs_and_outputs_rest_required_fields( request_type=metadata_service.QueryExecutionInputsAndOutputsRequest, ): @@ -20721,6 +24736,47 @@ def get_message_fields(field): assert response.description == "description_value" +def test_create_metadata_schema_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_metadata_schema + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_metadata_schema + ] = mock_rpc + + request = {} + client.create_metadata_schema(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_metadata_schema(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_metadata_schema_rest_required_fields( request_type=metadata_service.CreateMetadataSchemaRequest, ): @@ -21020,6 +25076,46 @@ def test_get_metadata_schema_rest(request_type): assert response.description == "description_value" +def test_get_metadata_schema_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_metadata_schema in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_metadata_schema + ] = mock_rpc + + request = {} + client.get_metadata_schema(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_metadata_schema(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_metadata_schema_rest_required_fields( request_type=metadata_service.GetMetadataSchemaRequest, ): @@ -21293,6 +25389,47 @@ def test_list_metadata_schemas_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_metadata_schemas_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_metadata_schemas + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_metadata_schemas + ] = mock_rpc + + request = {} + client.list_metadata_schemas(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_metadata_schemas(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_metadata_schemas_rest_required_fields( request_type=metadata_service.ListMetadataSchemasRequest, ): @@ -21641,6 +25778,47 @@ def test_query_artifact_lineage_subgraph_rest(request_type): assert isinstance(response, lineage_subgraph.LineageSubgraph) +def test_query_artifact_lineage_subgraph_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.query_artifact_lineage_subgraph + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.query_artifact_lineage_subgraph + ] = mock_rpc + + request = {} + client.query_artifact_lineage_subgraph(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.query_artifact_lineage_subgraph(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_query_artifact_lineage_subgraph_rest_required_fields( request_type=metadata_service.QueryArtifactLineageSubgraphRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py index 824b17f1b1..a8d61b1fbd 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py @@ -1213,6 +1213,9 @@ def test_search_migratable_resources_empty_call(): with mock.patch.object( type(client.transport.search_migratable_resources), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.search_migratable_resources() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1240,6 +1243,9 @@ def test_search_migratable_resources_non_empty_request_with_auto_populated_field with mock.patch.object( type(client.transport.search_migratable_resources), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.search_migratable_resources(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1250,6 +1256,46 @@ def test_search_migratable_resources_non_empty_request_with_auto_populated_field ) +def test_search_migratable_resources_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MigrationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.search_migratable_resources + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.search_migratable_resources + ] = mock_rpc + request = {} + client.search_migratable_resources(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_migratable_resources(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_search_migratable_resources_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1275,6 +1321,52 @@ async def test_search_migratable_resources_empty_call_async(): assert args[0] == migration_service.SearchMigratableResourcesRequest() +@pytest.mark.asyncio +async def test_search_migratable_resources_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MigrationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.search_migratable_resources + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.search_migratable_resources + ] = mock_object + + request = {} + await client.search_migratable_resources(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.search_migratable_resources(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_search_migratable_resources_async( transport: str = "grpc_asyncio", @@ -1717,6 +1809,9 @@ def test_batch_migrate_resources_empty_call(): with mock.patch.object( type(client.transport.batch_migrate_resources), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_migrate_resources() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1742,6 +1837,9 @@ def test_batch_migrate_resources_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.batch_migrate_resources), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_migrate_resources(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1750,6 +1848,50 @@ def test_batch_migrate_resources_non_empty_request_with_auto_populated_field(): ) +def test_batch_migrate_resources_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MigrationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_migrate_resources + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a 
string. + ) + client._transport._wrapped_methods[ + client._transport.batch_migrate_resources + ] = mock_rpc + request = {} + client.batch_migrate_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.batch_migrate_resources(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_batch_migrate_resources_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1773,6 +1915,56 @@ async def test_batch_migrate_resources_empty_call_async(): assert args[0] == migration_service.BatchMigrateResourcesRequest() +@pytest.mark.asyncio +async def test_batch_migrate_resources_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = MigrationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_migrate_resources + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_migrate_resources + ] = mock_object + + request = {} + await client.batch_migrate_resources(request) + + # Establish that 
the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.batch_migrate_resources(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_batch_migrate_resources_async( transport: str = "grpc_asyncio", @@ -2051,6 +2243,47 @@ def test_search_migratable_resources_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_search_migratable_resources_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MigrationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.search_migratable_resources + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.search_migratable_resources + ] = mock_rpc + + request = {} + client.search_migratable_resources(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_migratable_resources(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_search_migratable_resources_rest_required_fields( request_type=migration_service.SearchMigratableResourcesRequest, ): @@ -2381,6 +2614,51 @@ def test_batch_migrate_resources_rest(request_type): assert response.operation.name == "operations/spam" +def test_batch_migrate_resources_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = MigrationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_migrate_resources + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_migrate_resources + ] = mock_rpc + + request = {} + client.batch_migrate_resources(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.batch_migrate_resources(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_batch_migrate_resources_rest_required_fields( request_type=migration_service.BatchMigrateResourcesRequest, ): @@ -3257,19 +3535,22 @@ def test_parse_annotated_dataset_path(): def test_dataset_path(): project = "cuttlefish" - dataset = "mussel" - expected = "projects/{project}/datasets/{dataset}".format( + location = "mussel" + dataset = "winkle" + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( project=project, + location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, dataset) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "winkle", - "dataset": "nautilus", + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", } path = MigrationServiceClient.dataset_path(**expected) @@ -3279,22 +3560,19 @@ def test_parse_dataset_path(): def test_dataset_path(): - project = "scallop" - location = "abalone" - dataset = "squid" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project = "squid" + dataset = "clam" + expected = "projects/{project}/datasets/{dataset}".format( project=project, - location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "clam", - "location": "whelk", + "project": "whelk", "dataset": "octopus", } path = MigrationServiceClient.dataset_path(**expected) diff --git 
a/tests/unit/gapic/aiplatform_v1beta1/test_model_garden_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_model_garden_service.py index 1d3cb6fa55..fabcaba801 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_model_garden_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_model_garden_service.py @@ -1242,6 +1242,9 @@ def test_get_publisher_model_empty_call(): with mock.patch.object( type(client.transport.get_publisher_model), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_publisher_model() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1268,6 +1271,9 @@ def test_get_publisher_model_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_publisher_model), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_publisher_model(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1277,6 +1283,45 @@ def test_get_publisher_model_non_empty_request_with_auto_populated_field(): ) +def test_get_publisher_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_publisher_model in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.get_publisher_model + ] = mock_rpc + request = {} + client.get_publisher_model(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_publisher_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_publisher_model_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1308,6 +1353,52 @@ async def test_get_publisher_model_empty_call_async(): assert args[0] == model_garden_service.GetPublisherModelRequest() +@pytest.mark.asyncio +async def test_get_publisher_model_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_publisher_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_publisher_model + ] = mock_object + + request = {} + await client.get_publisher_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_publisher_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_publisher_model_async( transport: str = "grpc_asyncio", @@ -1571,6 +1662,9 @@ def test_list_publisher_models_empty_call(): with mock.patch.object( type(client.transport.list_publisher_models), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_publisher_models() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1600,6 +1694,9 @@ def test_list_publisher_models_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_publisher_models), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_publisher_models(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1612,6 +1709,46 @@ def test_list_publisher_models_non_empty_request_with_auto_populated_field(): ) +def test_list_publisher_models_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_publisher_models + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_publisher_models + ] = mock_rpc + request = {} + client.list_publisher_models(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_publisher_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_publisher_models_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1637,6 +1774,52 @@ async def test_list_publisher_models_empty_call_async(): assert args[0] == model_garden_service.ListPublisherModelsRequest() +@pytest.mark.asyncio +async def test_list_publisher_models_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelGardenServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_publisher_models + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_publisher_models + ] = mock_object + + request = {} + await client.list_publisher_models(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_publisher_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_publisher_models_async( transport: str = "grpc_asyncio", @@ -2088,6 +2271,46 @@ def test_get_publisher_model_rest(request_type): assert response.publisher_model_template == "publisher_model_template_value" +def test_get_publisher_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_publisher_model in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_publisher_model + ] = mock_rpc + + request = {} + client.get_publisher_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_publisher_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_publisher_model_rest_required_fields( request_type=model_garden_service.GetPublisherModelRequest, ): @@ -2368,6 +2591,47 @@ def test_list_publisher_models_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_publisher_models_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelGardenServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_publisher_models + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_publisher_models + ] = mock_rpc + + request = {} + client.list_publisher_models(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_publisher_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_publisher_models_rest_required_fields( request_type=model_garden_service.ListPublisherModelsRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_model_monitoring_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_model_monitoring_service.py index 0ec0c963a9..00075dc2c7 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_model_monitoring_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_model_monitoring_service.py @@ -1274,6 +1274,9 @@ def test_create_model_monitor_empty_call(): with mock.patch.object( type(client.transport.create_model_monitor), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_model_monitor() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1300,6 +1303,9 @@ def test_create_model_monitor_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_model_monitor), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_model_monitor(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1309,6 +1315,49 @@ def test_create_model_monitor_non_empty_request_with_auto_populated_field(): ) +def test_create_model_monitor_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_model_monitor in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_model_monitor + ] = mock_rpc + request = {} + client.create_model_monitor(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_model_monitor(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_model_monitor_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1332,6 +1381,56 @@ async def test_create_model_monitor_empty_call_async(): assert args[0] == model_monitoring_service.CreateModelMonitorRequest() +@pytest.mark.asyncio +async def test_create_model_monitor_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_model_monitor + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_model_monitor + ] = mock_object + + request = {} + await client.create_model_monitor(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_model_monitor(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_model_monitor_async( transport: str = "grpc_asyncio", @@ -1615,6 +1714,9 @@ def test_update_model_monitor_empty_call(): with mock.patch.object( type(client.transport.update_model_monitor), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_model_monitor() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1638,12 +1740,58 @@ def test_update_model_monitor_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.update_model_monitor), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_model_monitor(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_monitoring_service.UpdateModelMonitorRequest() +def test_update_model_monitor_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_model_monitor in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_model_monitor + ] = mock_rpc + request = {} + client.update_model_monitor(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_model_monitor(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_model_monitor_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1667,6 +1815,56 @@ async def test_update_model_monitor_empty_call_async(): assert args[0] == model_monitoring_service.UpdateModelMonitorRequest() +@pytest.mark.asyncio +async def test_update_model_monitor_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_model_monitor + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_model_monitor + ] = mock_object + + request = {} + await client.update_model_monitor(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_model_monitor(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_model_monitor_async( transport: str = "grpc_asyncio", @@ -1955,6 +2153,9 @@ def test_get_model_monitor_empty_call(): with mock.patch.object( type(client.transport.get_model_monitor), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_model_monitor() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1980,6 +2181,9 @@ def test_get_model_monitor_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_model_monitor), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_model_monitor(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1988,6 +2192,43 @@ def test_get_model_monitor_non_empty_request_with_auto_populated_field(): ) +def test_get_model_monitor_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_model_monitor in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_model_monitor + ] = mock_rpc + request = {} + client.get_model_monitor(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_model_monitor(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_model_monitor_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2014,6 +2255,52 @@ async def test_get_model_monitor_empty_call_async(): assert args[0] == model_monitoring_service.GetModelMonitorRequest() +@pytest.mark.asyncio +async def test_get_model_monitor_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_model_monitor + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_model_monitor + ] = mock_object + + request = {} + await client.get_model_monitor(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_model_monitor(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_model_monitor_async( transport: str = "grpc_asyncio", @@ -2259,6 +2546,9 @@ def test_list_model_monitors_empty_call(): with mock.patch.object( type(client.transport.list_model_monitors), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_model_monitors() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2286,6 +2576,9 @@ def test_list_model_monitors_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_model_monitors), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_model_monitors(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2296,6 +2589,45 @@ def test_list_model_monitors_non_empty_request_with_auto_populated_field(): ) +def test_list_model_monitors_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_model_monitors in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_model_monitors + ] = mock_rpc + request = {} + client.list_model_monitors(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_model_monitors(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_model_monitors_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2321,6 +2653,52 @@ async def test_list_model_monitors_empty_call_async(): assert args[0] == model_monitoring_service.ListModelMonitorsRequest() +@pytest.mark.asyncio +async def test_list_model_monitors_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_model_monitors + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_model_monitors + ] = mock_object + + request = {} + await client.list_model_monitors(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_model_monitors(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_model_monitors_async( transport: str = "grpc_asyncio", @@ -2759,6 +3137,9 @@ def test_delete_model_monitor_empty_call(): with mock.patch.object( type(client.transport.delete_model_monitor), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_model_monitor() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2784,6 +3165,9 @@ def test_delete_model_monitor_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_model_monitor), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_model_monitor(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2792,6 +3176,49 @@ def test_delete_model_monitor_non_empty_request_with_auto_populated_field(): ) +def test_delete_model_monitor_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_model_monitor in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_model_monitor + ] = mock_rpc + request = {} + client.delete_model_monitor(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_model_monitor(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_model_monitor_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2815,6 +3242,56 @@ async def test_delete_model_monitor_empty_call_async(): assert args[0] == model_monitoring_service.DeleteModelMonitorRequest() +@pytest.mark.asyncio +async def test_delete_model_monitor_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_model_monitor + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_model_monitor + ] = mock_object + + request = {} + await client.delete_model_monitor(request) + + # Establish that the underlying gRPC stub 
method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_model_monitor(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_model_monitor_async( transport: str = "grpc_asyncio", @@ -3061,6 +3538,9 @@ def test_create_model_monitoring_job_empty_call(): with mock.patch.object( type(client.transport.create_model_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_model_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3087,6 +3567,9 @@ def test_create_model_monitoring_job_non_empty_request_with_auto_populated_field with mock.patch.object( type(client.transport.create_model_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_model_monitoring_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3096,6 +3579,46 @@ def test_create_model_monitoring_job_non_empty_request_with_auto_populated_field ) +def test_create_model_monitoring_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_model_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_model_monitoring_job + ] = mock_rpc + request = {} + client.create_model_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_model_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_model_monitoring_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3124,6 +3647,52 @@ async def test_create_model_monitoring_job_empty_call_async(): assert args[0] == model_monitoring_service.CreateModelMonitoringJobRequest() +@pytest.mark.asyncio +async def test_create_model_monitoring_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_model_monitoring_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_model_monitoring_job + ] = mock_object + + request = {} + await client.create_model_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_model_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_model_monitoring_job_async( transport: str = "grpc_asyncio", @@ -3397,6 +3966,9 @@ def test_get_model_monitoring_job_empty_call(): with mock.patch.object( type(client.transport.get_model_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_model_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3422,6 +3994,9 @@ def test_get_model_monitoring_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_model_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_model_monitoring_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3430,6 +4005,46 @@ def test_get_model_monitoring_job_non_empty_request_with_auto_populated_field(): ) +def test_get_model_monitoring_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_model_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute 
client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_model_monitoring_job + ] = mock_rpc + request = {} + client.get_model_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_model_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_model_monitoring_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3458,6 +4073,52 @@ async def test_get_model_monitoring_job_empty_call_async(): assert args[0] == model_monitoring_service.GetModelMonitoringJobRequest() +@pytest.mark.asyncio +async def test_get_model_monitoring_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_model_monitoring_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_model_monitoring_job + ] = mock_object + + request = {} + await client.get_model_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_model_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_model_monitoring_job_async( transport: str = "grpc_asyncio", @@ -3707,6 +4368,9 @@ def test_list_model_monitoring_jobs_empty_call(): with mock.patch.object( type(client.transport.list_model_monitoring_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_model_monitoring_jobs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3734,6 +4398,9 @@ def test_list_model_monitoring_jobs_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.list_model_monitoring_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_model_monitoring_jobs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3744,6 +4411,46 @@ def test_list_model_monitoring_jobs_non_empty_request_with_auto_populated_field( ) +def test_list_model_monitoring_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_model_monitoring_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_model_monitoring_jobs + ] = mock_rpc + request = {} + client.list_model_monitoring_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_model_monitoring_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_model_monitoring_jobs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3769,6 +4476,52 @@ async def test_list_model_monitoring_jobs_empty_call_async(): assert args[0] == model_monitoring_service.ListModelMonitoringJobsRequest() +@pytest.mark.asyncio +async def test_list_model_monitoring_jobs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_model_monitoring_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_model_monitoring_jobs + ] = mock_object + + request = {} + await client.list_model_monitoring_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_model_monitoring_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_model_monitoring_jobs_async( transport: str = "grpc_asyncio", @@ -4211,6 +4964,9 @@ def test_delete_model_monitoring_job_empty_call(): with mock.patch.object( type(client.transport.delete_model_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_model_monitoring_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4236,6 +4992,9 @@ def test_delete_model_monitoring_job_non_empty_request_with_auto_populated_field with mock.patch.object( type(client.transport.delete_model_monitoring_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_model_monitoring_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4244,6 +5003,50 @@ def test_delete_model_monitoring_job_non_empty_request_with_auto_populated_field ) +def test_delete_model_monitoring_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_model_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_model_monitoring_job + ] = mock_rpc + request = {} + client.delete_model_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_model_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_model_monitoring_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4267,6 +5070,56 @@ async def test_delete_model_monitoring_job_empty_call_async(): assert args[0] == model_monitoring_service.DeleteModelMonitoringJobRequest() +@pytest.mark.asyncio +async def test_delete_model_monitoring_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_model_monitoring_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_model_monitoring_job + ] = mock_object + + request = {} + await client.delete_model_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_model_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_model_monitoring_job_async( transport: str = "grpc_asyncio", @@ -4507,6 +5360,9 @@ def test_search_model_monitoring_stats_empty_call(): with mock.patch.object( type(client.transport.search_model_monitoring_stats), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.search_model_monitoring_stats() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4533,6 +5389,9 @@ def test_search_model_monitoring_stats_non_empty_request_with_auto_populated_fie with mock.patch.object( type(client.transport.search_model_monitoring_stats), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.search_model_monitoring_stats(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4542,6 +5401,46 @@ def test_search_model_monitoring_stats_non_empty_request_with_auto_populated_fie ) +def test_search_model_monitoring_stats_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.search_model_monitoring_stats + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.search_model_monitoring_stats + ] = mock_rpc + request = {} + client.search_model_monitoring_stats(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_model_monitoring_stats(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_search_model_monitoring_stats_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4561,10 +5460,56 @@ async def test_search_model_monitoring_stats_empty_call_async(): next_page_token="next_page_token_value", ) ) - response = await client.search_model_monitoring_stats() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == model_monitoring_service.SearchModelMonitoringStatsRequest() + response = await client.search_model_monitoring_stats() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == model_monitoring_service.SearchModelMonitoringStatsRequest() + + +@pytest.mark.asyncio +async def test_search_model_monitoring_stats_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.search_model_monitoring_stats + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.search_model_monitoring_stats + ] = mock_object + + request = {} + await 
client.search_model_monitoring_stats(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.search_model_monitoring_stats(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 @pytest.mark.asyncio @@ -5023,6 +5968,9 @@ def test_search_model_monitoring_alerts_empty_call(): with mock.patch.object( type(client.transport.search_model_monitoring_alerts), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.search_model_monitoring_alerts() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5052,6 +6000,9 @@ def test_search_model_monitoring_alerts_non_empty_request_with_auto_populated_fi with mock.patch.object( type(client.transport.search_model_monitoring_alerts), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.search_model_monitoring_alerts(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5064,6 +6015,46 @@ def test_search_model_monitoring_alerts_non_empty_request_with_auto_populated_fi ) +def test_search_model_monitoring_alerts_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.search_model_monitoring_alerts + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.search_model_monitoring_alerts + ] = mock_rpc + request = {} + client.search_model_monitoring_alerts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_model_monitoring_alerts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_search_model_monitoring_alerts_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5090,6 +6081,52 @@ async def test_search_model_monitoring_alerts_empty_call_async(): assert args[0] == model_monitoring_service.SearchModelMonitoringAlertsRequest() +@pytest.mark.asyncio +async def test_search_model_monitoring_alerts_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.search_model_monitoring_alerts + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.search_model_monitoring_alerts + ] = mock_object + + request = {} + await client.search_model_monitoring_alerts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.search_model_monitoring_alerts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_search_model_monitoring_alerts_async( transport: str = "grpc_asyncio", @@ -5723,6 +6760,50 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_model_monitor_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_model_monitor in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_model_monitor + ] = mock_rpc + + request = {} + client.create_model_monitor(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_model_monitor(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_model_monitor_rest_required_fields( request_type=model_monitoring_service.CreateModelMonitorRequest, ): @@ -6206,6 +7287,50 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_model_monitor_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_model_monitor in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_model_monitor + ] = mock_rpc + + request = {} + client.update_model_monitor(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_model_monitor(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_model_monitor_rest_required_fields( request_type=model_monitoring_service.UpdateModelMonitorRequest, ): @@ -6501,6 +7626,44 @@ def test_get_model_monitor_rest(request_type): assert response.display_name == "display_name_value" +def test_get_model_monitor_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_model_monitor in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_model_monitor + ] = mock_rpc + + request = {} + client.get_model_monitor(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_model_monitor(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_model_monitor_rest_required_fields( request_type=model_monitoring_service.GetModelMonitorRequest, ): @@ -6773,6 +7936,46 @@ def test_list_model_monitors_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_model_monitors_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_model_monitors in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_model_monitors + ] = mock_rpc + + request = {} + client.list_model_monitors(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_model_monitors(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_model_monitors_rest_required_fields( request_type=model_monitoring_service.ListModelMonitorsRequest, ): @@ -7119,6 +8322,50 @@ def test_delete_model_monitor_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_model_monitor_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_model_monitor in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_model_monitor + ] = mock_rpc + + request = {} + client.delete_model_monitor(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_model_monitor(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_model_monitor_rest_required_fields( request_type=model_monitoring_service.DeleteModelMonitorRequest, ): @@ -7604,6 +8851,47 @@ def get_message_fields(field): assert response.schedule == "schedule_value" +def test_create_model_monitoring_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_model_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_model_monitoring_job + ] = mock_rpc + + request = {} + client.create_model_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_model_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_model_monitoring_job_rest_required_fields( request_type=model_monitoring_service.CreateModelMonitoringJobRequest, ): @@ -7903,6 +9191,47 @@ def test_get_model_monitoring_job_rest(request_type): assert response.schedule == "schedule_value" +def test_get_model_monitoring_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_model_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_model_monitoring_job + ] = mock_rpc + + request = {} + client.get_model_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_model_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_model_monitoring_job_rest_required_fields( request_type=model_monitoring_service.GetModelMonitoringJobRequest, ): @@ -8180,6 +9509,47 @@ def test_list_model_monitoring_jobs_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_model_monitoring_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_model_monitoring_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_model_monitoring_jobs + ] = mock_rpc + + request = {} + client.list_model_monitoring_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_model_monitoring_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_model_monitoring_jobs_rest_required_fields( request_type=model_monitoring_service.ListModelMonitoringJobsRequest, ): @@ -8538,6 +9908,51 @@ def test_delete_model_monitoring_job_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_model_monitoring_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_model_monitoring_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_model_monitoring_job + ] = mock_rpc + + request = {} + client.delete_model_monitoring_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_model_monitoring_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_model_monitoring_job_rest_required_fields( request_type=model_monitoring_service.DeleteModelMonitoringJobRequest, ): @@ -8813,6 +10228,47 @@ def test_search_model_monitoring_stats_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_search_model_monitoring_stats_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.search_model_monitoring_stats + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.search_model_monitoring_stats + ] = mock_rpc + + request = {} + client.search_model_monitoring_stats(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_model_monitoring_stats(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_search_model_monitoring_stats_rest_required_fields( request_type=model_monitoring_service.SearchModelMonitoringStatsRequest, ): @@ -9168,6 +10624,47 @@ def test_search_model_monitoring_alerts_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_search_model_monitoring_alerts_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelMonitoringServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.search_model_monitoring_alerts + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.search_model_monitoring_alerts + ] = mock_rpc + + request = {} + client.search_model_monitoring_alerts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.search_model_monitoring_alerts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_search_model_monitoring_alerts_rest_required_fields( request_type=model_monitoring_service.SearchModelMonitoringAlertsRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py index 9399935d71..3796525e89 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py @@ -1171,6 +1171,9 @@ def test_upload_model_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.upload_model() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1197,6 +1200,9 @@ def test_upload_model_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.upload_model(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1208,6 +1214,45 @@ def test_upload_model_non_empty_request_with_auto_populated_field(): ) +def test_upload_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.upload_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.upload_model] = mock_rpc + request = {} + client.upload_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.upload_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_upload_model_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1229,6 +1274,56 @@ async def test_upload_model_empty_call_async(): assert args[0] == model_service.UploadModelRequest() +@pytest.mark.asyncio +async def test_upload_model_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.upload_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.upload_model + ] = mock_object + + request = {} + await client.upload_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.upload_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_upload_model_async( transport: str = "grpc_asyncio", request_type=model_service.UploadModelRequest @@ -1498,6 +1593,9 @@ def test_get_model_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_model() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1521,6 +1619,9 @@ def test_get_model_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_model(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1529,6 +1630,41 @@ def test_get_model_non_empty_request_with_auto_populated_field(): ) +def test_get_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_model] = mock_rpc + request = {} + client.get_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_model_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1571,6 +1707,50 @@ async def test_get_model_empty_call_async(): assert args[0] == model_service.GetModelRequest() +@pytest.mark.asyncio +async def test_get_model_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_model + ] = mock_object + + request = {} + await client.get_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_model_async( transport: str = "grpc_asyncio", request_type=model_service.GetModelRequest @@ -1833,6 +2013,9 @@ def test_list_models_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_models), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_models() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1858,6 +2041,9 @@ def test_list_models_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_models), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_models(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1868,6 +2054,41 @@ def test_list_models_non_empty_request_with_auto_populated_field(): ) +def test_list_models_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_models in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_models] = mock_rpc + request = {} + client.list_models(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_models_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1891,6 +2112,52 @@ async def test_list_models_empty_call_async(): assert args[0] == model_service.ListModelsRequest() +@pytest.mark.asyncio +async def test_list_models_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_models + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_models + ] = mock_object + + request = {} + await client.list_models(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_models_async( transport: str = "grpc_asyncio", request_type=model_service.ListModelsRequest @@ -2313,6 +2580,9 @@ def test_list_model_versions_empty_call(): with mock.patch.object( type(client.transport.list_model_versions), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_model_versions() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2341,6 +2611,9 @@ def test_list_model_versions_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_model_versions), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_model_versions(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2352,6 +2625,45 @@ def test_list_model_versions_non_empty_request_with_auto_populated_field(): ) +def test_list_model_versions_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_model_versions in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_model_versions + ] = mock_rpc + request = {} + client.list_model_versions(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_model_versions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_model_versions_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2377,6 +2689,52 @@ async def test_list_model_versions_empty_call_async(): assert args[0] == model_service.ListModelVersionsRequest() +@pytest.mark.asyncio +async def test_list_model_versions_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_model_versions + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_model_versions + ] = mock_object + + request = {} + await client.list_model_versions(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_model_versions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_model_versions_async( transport: str = "grpc_asyncio", request_type=model_service.ListModelVersionsRequest @@ -2847,6 +3205,9 @@ def test_update_model_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_model() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2868,12 +3229,50 @@ def test_update_model_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_model(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.UpdateModelRequest() +def test_update_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_model] = mock_rpc + request = {} + client.update_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_model_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2916,6 +3315,52 @@ async def test_update_model_empty_call_async(): assert args[0] == model_service.UpdateModelRequest() +@pytest.mark.asyncio +async def test_update_model_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_model + ] = mock_object + + request = {} + await client.update_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_model_async( transport: str = "grpc_asyncio", request_type=model_service.UpdateModelRequest @@ -3189,6 +3634,9 @@ def test_update_explanation_dataset_empty_call(): with mock.patch.object( type(client.transport.update_explanation_dataset), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_explanation_dataset() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3214,6 +3662,9 @@ def test_update_explanation_dataset_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.update_explanation_dataset), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_explanation_dataset(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3222,6 +3673,50 @@ def test_update_explanation_dataset_non_empty_request_with_auto_populated_field( ) +def test_update_explanation_dataset_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_explanation_dataset + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_explanation_dataset + ] = mock_rpc + request = {} + client.update_explanation_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_explanation_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_explanation_dataset_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3245,6 +3740,56 @@ async def test_update_explanation_dataset_empty_call_async(): assert args[0] == model_service.UpdateExplanationDatasetRequest() +@pytest.mark.asyncio +async def test_update_explanation_dataset_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_explanation_dataset + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_explanation_dataset + ] = mock_object + + request = {} + await client.update_explanation_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_explanation_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_explanation_dataset_async( transport: str = "grpc_asyncio", @@ -3478,6 +4023,9 @@ def test_delete_model_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_model() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3501,6 +4049,9 @@ def test_delete_model_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_model(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3509,6 +4060,45 @@ def test_delete_model_non_empty_request_with_auto_populated_field(): ) +def test_delete_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_model] = mock_rpc + request = {} + client.delete_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_model_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3531,10 +4121,60 @@ async def test_delete_model_empty_call_async(): @pytest.mark.asyncio -async def test_delete_model_async( - transport: str = "grpc_asyncio", request_type=model_service.DeleteModelRequest +async def test_delete_model_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", ): - client = ModelServiceAsyncClient( + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_model + ] = mock_object + + request = {} + await client.delete_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_model_async( + transport: str = "grpc_asyncio", request_type=model_service.DeleteModelRequest +): + client = ModelServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) @@ -3756,6 +4396,9 @@ def test_delete_model_version_empty_call(): with mock.patch.object( type(client.transport.delete_model_version), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_model_version() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3781,6 +4424,9 @@ def test_delete_model_version_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_model_version), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_model_version(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3789,6 +4435,49 @@ def test_delete_model_version_non_empty_request_with_auto_populated_field(): ) +def test_delete_model_version_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_model_version in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_model_version + ] = mock_rpc + request = {} + client.delete_model_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_model_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_model_version_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3812,6 +4501,56 @@ async def test_delete_model_version_empty_call_async(): assert args[0] == model_service.DeleteModelVersionRequest() +@pytest.mark.asyncio +async def test_delete_model_version_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_model_version + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_model_version + ] = mock_object + + request = {} + await client.delete_model_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_model_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_model_version_async( transport: str = "grpc_asyncio", @@ -4086,6 +4825,9 @@ def test_merge_version_aliases_empty_call(): with mock.patch.object( type(client.transport.merge_version_aliases), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.merge_version_aliases() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4111,6 +4853,9 @@ def test_merge_version_aliases_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.merge_version_aliases), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.merge_version_aliases(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4119,6 +4864,46 @@ def test_merge_version_aliases_non_empty_request_with_auto_populated_field(): ) +def test_merge_version_aliases_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.merge_version_aliases + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.merge_version_aliases + ] = mock_rpc + request = {} + client.merge_version_aliases(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.merge_version_aliases(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_merge_version_aliases_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4163,6 +4948,52 @@ async def test_merge_version_aliases_empty_call_async(): assert args[0] == model_service.MergeVersionAliasesRequest() +@pytest.mark.asyncio +async def test_merge_version_aliases_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.merge_version_aliases + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.merge_version_aliases + ] = mock_object + + request = {} + await client.merge_version_aliases(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.merge_version_aliases(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_merge_version_aliases_async( transport: str = "grpc_asyncio", @@ -4443,6 +5274,9 @@ def test_export_model_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.export_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.export_model() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4466,6 +5300,9 @@ def test_export_model_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.export_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.export_model(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4474,6 +5311,45 @@ def test_export_model_non_empty_request_with_auto_populated_field(): ) +def test_export_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.export_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.export_model] = mock_rpc + request = {} + client.export_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.export_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_export_model_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4495,6 +5371,56 @@ async def test_export_model_empty_call_async(): assert args[0] == model_service.ExportModelRequest() +@pytest.mark.asyncio +async def test_export_model_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.export_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.export_model + ] = mock_object + + request = {} + await client.export_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.export_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_export_model_async( transport: str = "grpc_asyncio", request_type=model_service.ExportModelRequest @@ -4739,6 +5665,9 @@ def test_copy_model_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.copy_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.copy_model() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4765,6 +5694,9 @@ def test_copy_model_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.copy_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.copy_model(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4776,6 +5708,45 @@ def test_copy_model_non_empty_request_with_auto_populated_field(): ) +def test_copy_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.copy_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.copy_model] = mock_rpc + request = {} + client.copy_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.copy_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_copy_model_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4797,6 +5768,54 @@ async def test_copy_model_empty_call_async(): assert args[0] == model_service.CopyModelRequest() +@pytest.mark.asyncio +async def test_copy_model_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.copy_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.copy_model + ] = mock_object + + request = {} + await client.copy_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.copy_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_copy_model_async( transport: str = "grpc_asyncio", request_type=model_service.CopyModelRequest @@ -5042,6 +6061,9 @@ def test_import_model_evaluation_empty_call(): with mock.patch.object( type(client.transport.import_model_evaluation), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.import_model_evaluation() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5067,6 +6089,9 @@ def test_import_model_evaluation_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.import_model_evaluation), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.import_model_evaluation(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5075,6 +6100,46 @@ def test_import_model_evaluation_non_empty_request_with_auto_populated_field(): ) +def test_import_model_evaluation_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.import_model_evaluation + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.import_model_evaluation + ] = mock_rpc + request = {} + client.import_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.import_model_evaluation(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_import_model_evaluation_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5103,6 +6168,52 @@ async def test_import_model_evaluation_empty_call_async(): assert args[0] == model_service.ImportModelEvaluationRequest() +@pytest.mark.asyncio +async def test_import_model_evaluation_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.import_model_evaluation + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.import_model_evaluation + ] = mock_object + + request = {} + await client.import_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.import_model_evaluation(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_import_model_evaluation_async( transport: str = "grpc_asyncio", @@ -5364,6 +6475,9 @@ def test_batch_import_model_evaluation_slices_empty_call(): with mock.patch.object( type(client.transport.batch_import_model_evaluation_slices), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_import_model_evaluation_slices() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5389,6 +6503,9 @@ def test_batch_import_model_evaluation_slices_non_empty_request_with_auto_popula with mock.patch.object( type(client.transport.batch_import_model_evaluation_slices), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.batch_import_model_evaluation_slices(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5397,6 +6514,46 @@ def test_batch_import_model_evaluation_slices_non_empty_request_with_auto_popula ) +def test_batch_import_model_evaluation_slices_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_import_model_evaluation_slices + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_import_model_evaluation_slices + ] = mock_rpc + request = {} + client.batch_import_model_evaluation_slices(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.batch_import_model_evaluation_slices(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_batch_import_model_evaluation_slices_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5424,6 +6581,52 @@ async def test_batch_import_model_evaluation_slices_empty_call_async(): assert args[0] == model_service.BatchImportModelEvaluationSlicesRequest() +@pytest.mark.asyncio +async def test_batch_import_model_evaluation_slices_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_import_model_evaluation_slices + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_import_model_evaluation_slices + ] = mock_object + + request = {} + await client.batch_import_model_evaluation_slices(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.batch_import_model_evaluation_slices(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_batch_import_model_evaluation_slices_async( transport: str = "grpc_asyncio", @@ -5689,6 +6892,9 @@ def test_batch_import_evaluated_annotations_empty_call(): with mock.patch.object( type(client.transport.batch_import_evaluated_annotations), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_import_evaluated_annotations() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5714,6 +6920,9 @@ def test_batch_import_evaluated_annotations_non_empty_request_with_auto_populate with mock.patch.object( type(client.transport.batch_import_evaluated_annotations), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.batch_import_evaluated_annotations(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5722,6 +6931,46 @@ def test_batch_import_evaluated_annotations_non_empty_request_with_auto_populate ) +def test_batch_import_evaluated_annotations_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_import_evaluated_annotations + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_import_evaluated_annotations + ] = mock_rpc + request = {} + client.batch_import_evaluated_annotations(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.batch_import_evaluated_annotations(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_batch_import_evaluated_annotations_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5747,6 +6996,52 @@ async def test_batch_import_evaluated_annotations_empty_call_async(): assert args[0] == model_service.BatchImportEvaluatedAnnotationsRequest() +@pytest.mark.asyncio +async def test_batch_import_evaluated_annotations_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_import_evaluated_annotations + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_import_evaluated_annotations + ] = mock_object + + request = {} + await client.batch_import_evaluated_annotations(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.batch_import_evaluated_annotations(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_batch_import_evaluated_annotations_async( transport: str = "grpc_asyncio", @@ -6030,6 +7325,9 @@ def test_get_model_evaluation_empty_call(): with mock.patch.object( type(client.transport.get_model_evaluation), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_model_evaluation() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6055,6 +7353,9 @@ def test_get_model_evaluation_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_model_evaluation), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_model_evaluation(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6063,6 +7364,45 @@ def test_get_model_evaluation_non_empty_request_with_auto_populated_field(): ) +def test_get_model_evaluation_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_model_evaluation in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.get_model_evaluation + ] = mock_rpc + request = {} + client.get_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_model_evaluation(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_model_evaluation_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6091,6 +7431,52 @@ async def test_get_model_evaluation_empty_call_async(): assert args[0] == model_service.GetModelEvaluationRequest() +@pytest.mark.asyncio +async def test_get_model_evaluation_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_model_evaluation + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_model_evaluation + ] = mock_object + + request = {} + await client.get_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_model_evaluation(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_model_evaluation_async( transport: str = "grpc_asyncio", @@ -6340,6 +7726,9 @@ def test_list_model_evaluations_empty_call(): with mock.patch.object( type(client.transport.list_model_evaluations), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_model_evaluations() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6367,6 +7756,9 @@ def test_list_model_evaluations_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_model_evaluations), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_model_evaluations(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6377,6 +7769,46 @@ def test_list_model_evaluations_non_empty_request_with_auto_populated_field(): ) +def test_list_model_evaluations_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_model_evaluations + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_model_evaluations + ] = mock_rpc + request = {} + client.list_model_evaluations(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_model_evaluations(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_model_evaluations_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6402,6 +7834,52 @@ async def test_list_model_evaluations_empty_call_async(): assert args[0] == model_service.ListModelEvaluationsRequest() +@pytest.mark.asyncio +async def test_list_model_evaluations_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_model_evaluations + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_model_evaluations + ] = mock_object + + request = {} + await client.list_model_evaluations(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_model_evaluations(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_model_evaluations_async( transport: str = "grpc_asyncio", @@ -6845,6 +8323,9 @@ def test_get_model_evaluation_slice_empty_call(): with mock.patch.object( type(client.transport.get_model_evaluation_slice), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_model_evaluation_slice() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6870,6 +8351,9 @@ def test_get_model_evaluation_slice_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.get_model_evaluation_slice), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_model_evaluation_slice(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6878,6 +8362,46 @@ def test_get_model_evaluation_slice_non_empty_request_with_auto_populated_field( ) +def test_get_model_evaluation_slice_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_model_evaluation_slice + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) 
expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_model_evaluation_slice + ] = mock_rpc + request = {} + client.get_model_evaluation_slice(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_model_evaluation_slice(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_model_evaluation_slice_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6904,6 +8428,52 @@ async def test_get_model_evaluation_slice_empty_call_async(): assert args[0] == model_service.GetModelEvaluationSliceRequest() +@pytest.mark.asyncio +async def test_get_model_evaluation_slice_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_model_evaluation_slice + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_model_evaluation_slice + ] = mock_object + + request = {} + await client.get_model_evaluation_slice(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_model_evaluation_slice(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_model_evaluation_slice_async( transport: str = "grpc_asyncio", @@ -7149,6 +8719,9 @@ def test_list_model_evaluation_slices_empty_call(): with mock.patch.object( type(client.transport.list_model_evaluation_slices), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_model_evaluation_slices() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7176,6 +8749,9 @@ def test_list_model_evaluation_slices_non_empty_request_with_auto_populated_fiel with mock.patch.object( type(client.transport.list_model_evaluation_slices), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_model_evaluation_slices(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7186,6 +8762,46 @@ def test_list_model_evaluation_slices_non_empty_request_with_auto_populated_fiel ) +def test_list_model_evaluation_slices_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_model_evaluation_slices + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_model_evaluation_slices + ] = mock_rpc + request = {} + client.list_model_evaluation_slices(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_model_evaluation_slices(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_model_evaluation_slices_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7211,6 +8827,52 @@ async def test_list_model_evaluation_slices_empty_call_async(): assert args[0] == model_service.ListModelEvaluationSlicesRequest() +@pytest.mark.asyncio +async def test_list_model_evaluation_slices_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_model_evaluation_slices + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_model_evaluation_slices + ] = mock_object + + request = {} + await client.list_model_evaluation_slices(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_model_evaluation_slices(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_model_evaluation_slices_async( transport: str = "grpc_asyncio", @@ -7642,6 +9304,46 @@ def test_upload_model_rest(request_type): assert response.operation.name == "operations/spam" +def test_upload_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.upload_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.upload_model] = mock_rpc + + request = {} + client.upload_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.upload_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_upload_model_rest_required_fields( request_type=model_service.UploadModelRequest, ): @@ -7951,6 +9653,42 @@ def test_get_model_rest(request_type): assert response.metadata_artifact == "metadata_artifact_value" +def test_get_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_model] = mock_rpc + + request = {} + client.get_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_model_rest_required_fields(request_type=model_service.GetModelRequest): transport_class = transports.ModelServiceRestTransport @@ -8212,6 +9950,42 @@ def test_list_models_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_models_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_models in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_models] = mock_rpc + + request = {} + client.list_models(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_models_rest_required_fields(request_type=model_service.ListModelsRequest): transport_class = transports.ModelServiceRestTransport @@ -8542,13 +10316,53 @@ def test_list_model_versions_rest(request_type): return_value = model_service.ListModelVersionsResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) - response_value._content = json_return_value.encode("UTF-8") - req.return_value = response_value - response = client.list_model_versions(request) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.list_model_versions(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListModelVersionsPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_model_versions_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_model_versions in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_model_versions + ] = mock_rpc + + request = {} + client.list_model_versions(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_model_versions(request) - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListModelVersionsPager) - assert response.next_page_token == "next_page_token_value" + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 def test_list_model_versions_rest_required_fields( @@ -9110,6 +10924,42 @@ def get_message_fields(field): assert response.metadata_artifact == "metadata_artifact_value" +def test_update_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_model] = mock_rpc + + request = {} + client.update_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_model_rest_required_fields( request_type=model_service.UpdateModelRequest, ): @@ -9382,6 +11232,51 @@ def test_update_explanation_dataset_rest(request_type): assert response.operation.name == "operations/spam" +def test_update_explanation_dataset_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_explanation_dataset + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_explanation_dataset + ] = mock_rpc + + request = {} + client.update_explanation_dataset(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_explanation_dataset(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_explanation_dataset_rest_required_fields( request_type=model_service.UpdateExplanationDatasetRequest, ): @@ -9642,6 +11537,46 @@ def test_delete_model_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_model] = mock_rpc + + request = {} + client.delete_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_model_rest_required_fields( request_type=model_service.DeleteModelRequest, ): @@ -9901,6 +11836,50 @@ def test_delete_model_version_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_model_version_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_model_version in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_model_version + ] = mock_rpc + + request = {} + client.delete_model_version(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_model_version(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_model_version_rest_required_fields( request_type=model_service.DeleteModelVersionRequest, ): @@ -10199,6 +12178,47 @@ def test_merge_version_aliases_rest(request_type): assert response.metadata_artifact == "metadata_artifact_value" +def test_merge_version_aliases_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.merge_version_aliases + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.merge_version_aliases + ] = mock_rpc + + request = {} + client.merge_version_aliases(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.merge_version_aliases(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_merge_version_aliases_rest_required_fields( request_type=model_service.MergeVersionAliasesRequest, ): @@ -10474,6 +12494,46 @@ def test_export_model_rest(request_type): assert response.operation.name == "operations/spam" +def test_export_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.export_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.export_model] = mock_rpc + + request = {} + client.export_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.export_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_export_model_rest_required_fields( request_type=model_service.ExportModelRequest, ): @@ -10748,6 +12808,46 @@ def test_copy_model_rest(request_type): assert response.operation.name == "operations/spam" +def test_copy_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.copy_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.copy_model] = mock_rpc + + request = {} + client.copy_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.copy_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_copy_model_rest_required_fields(request_type=model_service.CopyModelRequest): transport_class = transports.ModelServiceRestTransport @@ -11029,6 +13129,47 @@ def test_import_model_evaluation_rest(request_type): assert response.slice_dimensions == ["slice_dimensions_value"] +def test_import_model_evaluation_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.import_model_evaluation + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.import_model_evaluation + ] = mock_rpc + + request = {} + client.import_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.import_model_evaluation(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_import_model_evaluation_rest_required_fields( request_type=model_service.ImportModelEvaluationRequest, ): @@ -11313,6 +13454,47 @@ def test_batch_import_model_evaluation_slices_rest(request_type): ] +def test_batch_import_model_evaluation_slices_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_import_model_evaluation_slices + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_import_model_evaluation_slices + ] = mock_rpc + + request = {} + client.batch_import_model_evaluation_slices(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.batch_import_model_evaluation_slices(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_batch_import_model_evaluation_slices_rest_required_fields( request_type=model_service.BatchImportModelEvaluationSlicesRequest, ): @@ -11616,6 +13798,47 @@ def test_batch_import_evaluated_annotations_rest(request_type): assert response.imported_evaluated_annotations_count == 3859 +def test_batch_import_evaluated_annotations_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_import_evaluated_annotations + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_import_evaluated_annotations + ] = mock_rpc + + request = {} + client.batch_import_evaluated_annotations(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.batch_import_evaluated_annotations(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_batch_import_evaluated_annotations_rest_required_fields( request_type=model_service.BatchImportEvaluatedAnnotationsRequest, ): @@ -11926,6 +14149,46 @@ def test_get_model_evaluation_rest(request_type): assert response.slice_dimensions == ["slice_dimensions_value"] +def test_get_model_evaluation_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_model_evaluation in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_model_evaluation + ] = mock_rpc + + request = {} + client.get_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_model_evaluation(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_model_evaluation_rest_required_fields( request_type=model_service.GetModelEvaluationRequest, ): @@ -12197,6 +14460,47 @@ def test_list_model_evaluations_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_model_evaluations_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_model_evaluations + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_model_evaluations + ] = mock_rpc + + request = {} + client.list_model_evaluations(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_model_evaluations(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_model_evaluations_rest_required_fields( request_type=model_service.ListModelEvaluationsRequest, ): @@ -12544,6 +14848,47 @@ def test_get_model_evaluation_slice_rest(request_type): assert response.metrics_schema_uri == "metrics_schema_uri_value" +def test_get_model_evaluation_slice_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_model_evaluation_slice + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_model_evaluation_slice + ] = mock_rpc + + request = {} + client.get_model_evaluation_slice(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_model_evaluation_slice(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_model_evaluation_slice_rest_required_fields( request_type=model_service.GetModelEvaluationSliceRequest, ): @@ -12817,6 +15162,47 @@ def test_list_model_evaluation_slices_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_model_evaluation_slices_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_model_evaluation_slices + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_model_evaluation_slices + ] = mock_rpc + + request = {} + client.list_model_evaluation_slices(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_model_evaluation_slices(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_model_evaluation_slices_rest_required_fields( request_type=model_service.ListModelEvaluationSlicesRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_notebook_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_notebook_service.py index 7a8c62184e..9292848539 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_notebook_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_notebook_service.py @@ -58,9 +58,11 @@ from google.cloud.aiplatform_v1beta1.services.notebook_service import pagers from google.cloud.aiplatform_v1beta1.services.notebook_service import transports from google.cloud.aiplatform_v1beta1.types import accelerator_type +from google.cloud.aiplatform_v1beta1.types import job_state from google.cloud.aiplatform_v1beta1.types import machine_resources from google.cloud.aiplatform_v1beta1.types import network_spec from google.cloud.aiplatform_v1beta1.types import notebook_euc_config +from google.cloud.aiplatform_v1beta1.types import notebook_execution_job from google.cloud.aiplatform_v1beta1.types import notebook_idle_shutdown_config from google.cloud.aiplatform_v1beta1.types import notebook_runtime from google.cloud.aiplatform_v1beta1.types import ( @@ -79,6 +81,7 @@ from google.protobuf import empty_pb2 # type: ignore from google.protobuf import field_mask_pb2 # type: ignore from google.protobuf import timestamp_pb2 # type: ignore +from google.rpc import status_pb2 # type: ignore import google.auth @@ -1218,6 +1221,9 @@ def test_create_notebook_runtime_template_empty_call(): with mock.patch.object( type(client.transport.create_notebook_runtime_template), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_notebook_runtime_template() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1244,6 +1250,9 @@ def test_create_notebook_runtime_template_non_empty_request_with_auto_populated_ with mock.patch.object( type(client.transport.create_notebook_runtime_template), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_notebook_runtime_template(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1253,6 +1262,50 @@ def test_create_notebook_runtime_template_non_empty_request_with_auto_populated_ ) +def test_create_notebook_runtime_template_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_notebook_runtime_template + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_notebook_runtime_template + ] = mock_rpc + request = {} + client.create_notebook_runtime_template(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_notebook_runtime_template(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_notebook_runtime_template_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1276,6 +1329,56 @@ async def test_create_notebook_runtime_template_empty_call_async(): assert args[0] == notebook_service.CreateNotebookRuntimeTemplateRequest() +@pytest.mark.asyncio +async def test_create_notebook_runtime_template_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_notebook_runtime_template + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_notebook_runtime_template + ] = mock_object + + request = {} + await client.create_notebook_runtime_template(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_notebook_runtime_template(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_notebook_runtime_template_async( transport: str = "grpc_asyncio", @@ -1561,6 +1664,9 @@ def test_get_notebook_runtime_template_empty_call(): with mock.patch.object( type(client.transport.get_notebook_runtime_template), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_notebook_runtime_template() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1586,6 +1692,9 @@ def test_get_notebook_runtime_template_non_empty_request_with_auto_populated_fie with mock.patch.object( type(client.transport.get_notebook_runtime_template), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_notebook_runtime_template(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1594,6 +1703,46 @@ def test_get_notebook_runtime_template_non_empty_request_with_auto_populated_fie ) +def test_get_notebook_runtime_template_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_notebook_runtime_template + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_notebook_runtime_template + ] = mock_rpc + request = {} + client.get_notebook_runtime_template(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_notebook_runtime_template(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_notebook_runtime_template_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1626,6 +1775,52 @@ async def test_get_notebook_runtime_template_empty_call_async(): assert args[0] == notebook_service.GetNotebookRuntimeTemplateRequest() +@pytest.mark.asyncio +async def test_get_notebook_runtime_template_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_notebook_runtime_template + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_notebook_runtime_template + ] = mock_object + + request = {} + await client.get_notebook_runtime_template(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_notebook_runtime_template(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_notebook_runtime_template_async( transport: str = "grpc_asyncio", @@ -1886,6 +2081,9 @@ def test_list_notebook_runtime_templates_empty_call(): with mock.patch.object( type(client.transport.list_notebook_runtime_templates), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_notebook_runtime_templates() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1914,6 +2112,9 @@ def test_list_notebook_runtime_templates_non_empty_request_with_auto_populated_f with mock.patch.object( type(client.transport.list_notebook_runtime_templates), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_notebook_runtime_templates(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1925,6 +2126,46 @@ def test_list_notebook_runtime_templates_non_empty_request_with_auto_populated_f ) +def test_list_notebook_runtime_templates_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_notebook_runtime_templates + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_notebook_runtime_templates + ] = mock_rpc + request = {} + client.list_notebook_runtime_templates(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_notebook_runtime_templates(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_notebook_runtime_templates_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1950,6 +2191,52 @@ async def test_list_notebook_runtime_templates_empty_call_async(): assert args[0] == notebook_service.ListNotebookRuntimeTemplatesRequest() +@pytest.mark.asyncio +async def test_list_notebook_runtime_templates_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_notebook_runtime_templates + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_notebook_runtime_templates + ] = mock_object + + request = {} + await client.list_notebook_runtime_templates(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_notebook_runtime_templates(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_notebook_runtime_templates_async( transport: str = "grpc_asyncio", @@ -2392,6 +2679,9 @@ def test_delete_notebook_runtime_template_empty_call(): with mock.patch.object( type(client.transport.delete_notebook_runtime_template), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_notebook_runtime_template() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2417,6 +2707,9 @@ def test_delete_notebook_runtime_template_non_empty_request_with_auto_populated_ with mock.patch.object( type(client.transport.delete_notebook_runtime_template), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_notebook_runtime_template(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2425,6 +2718,50 @@ def test_delete_notebook_runtime_template_non_empty_request_with_auto_populated_ ) +def test_delete_notebook_runtime_template_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_notebook_runtime_template + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_notebook_runtime_template + ] = mock_rpc + request = {} + client.delete_notebook_runtime_template(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_notebook_runtime_template(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_notebook_runtime_template_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2448,6 +2785,56 @@ async def test_delete_notebook_runtime_template_empty_call_async(): assert args[0] == notebook_service.DeleteNotebookRuntimeTemplateRequest() +@pytest.mark.asyncio +async def test_delete_notebook_runtime_template_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_notebook_runtime_template + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_notebook_runtime_template + ] = mock_object + + request = {} + await client.delete_notebook_runtime_template(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_notebook_runtime_template(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_notebook_runtime_template_async( transport: str = "grpc_asyncio", @@ -2685,6 +3072,9 @@ def test_assign_notebook_runtime_empty_call(): with mock.patch.object( type(client.transport.assign_notebook_runtime), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.assign_notebook_runtime() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2712,6 +3102,9 @@ def test_assign_notebook_runtime_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.assign_notebook_runtime), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.assign_notebook_runtime(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2722,6 +3115,50 @@ def test_assign_notebook_runtime_non_empty_request_with_auto_populated_field(): ) +def test_assign_notebook_runtime_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.assign_notebook_runtime + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.assign_notebook_runtime + ] = mock_rpc + request = {} + client.assign_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.assign_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_assign_notebook_runtime_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2745,6 +3182,56 @@ async def test_assign_notebook_runtime_empty_call_async(): assert args[0] == notebook_service.AssignNotebookRuntimeRequest() +@pytest.mark.asyncio +async def test_assign_notebook_runtime_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.assign_notebook_runtime + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.assign_notebook_runtime + ] = mock_object + + request = {} + await client.assign_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.assign_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_assign_notebook_runtime_async( transport: str = "grpc_asyncio", @@ -3042,6 +3529,9 @@ def test_get_notebook_runtime_empty_call(): with mock.patch.object( type(client.transport.get_notebook_runtime), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_notebook_runtime() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3067,6 +3557,9 @@ def test_get_notebook_runtime_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_notebook_runtime), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_notebook_runtime(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3075,6 +3568,45 @@ def test_get_notebook_runtime_non_empty_request_with_auto_populated_field(): ) +def test_get_notebook_runtime_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_notebook_runtime in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_notebook_runtime + ] = mock_rpc + request = {} + client.get_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_notebook_runtime_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3111,6 +3643,52 @@ async def test_get_notebook_runtime_empty_call_async(): assert args[0] == notebook_service.GetNotebookRuntimeRequest() +@pytest.mark.asyncio +async def test_get_notebook_runtime_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_notebook_runtime + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_notebook_runtime + ] = mock_object + + request = {} + await client.get_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_notebook_runtime_async( transport: str = "grpc_asyncio", @@ -3381,6 +3959,9 @@ def test_list_notebook_runtimes_empty_call(): with mock.patch.object( type(client.transport.list_notebook_runtimes), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_notebook_runtimes() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3409,6 +3990,9 @@ def test_list_notebook_runtimes_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_notebook_runtimes), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_notebook_runtimes(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3420,14 +4004,54 @@ def test_list_notebook_runtimes_non_empty_request_with_auto_populated_field(): ) -@pytest.mark.asyncio -async def test_list_notebook_runtimes_empty_call_async(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. 
- client = NotebookServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc_asyncio", - ) +def test_list_notebook_runtimes_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_notebook_runtimes + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_notebook_runtimes + ] = mock_rpc + request = {} + client.list_notebook_runtimes(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_notebook_runtimes(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_notebook_runtimes_empty_call_async(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( @@ -3445,6 +4069,52 @@ async def test_list_notebook_runtimes_empty_call_async(): assert args[0] == notebook_service.ListNotebookRuntimesRequest() +@pytest.mark.asyncio +async def test_list_notebook_runtimes_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_notebook_runtimes + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_notebook_runtimes + ] = mock_object + + request = {} + await client.list_notebook_runtimes(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.list_notebook_runtimes(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_notebook_runtimes_async( transport: str = "grpc_asyncio", @@ -3883,6 +4553,9 @@ def test_delete_notebook_runtime_empty_call(): with mock.patch.object( type(client.transport.delete_notebook_runtime), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_notebook_runtime() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3908,6 +4581,9 @@ def test_delete_notebook_runtime_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_notebook_runtime), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_notebook_runtime(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3916,6 +4592,50 @@ def test_delete_notebook_runtime_non_empty_request_with_auto_populated_field(): ) +def test_delete_notebook_runtime_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_notebook_runtime + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_notebook_runtime + ] = mock_rpc + request = {} + client.delete_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_notebook_runtime_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3939,6 +4659,56 @@ async def test_delete_notebook_runtime_empty_call_async(): assert args[0] == notebook_service.DeleteNotebookRuntimeRequest() +@pytest.mark.asyncio +async def test_delete_notebook_runtime_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_notebook_runtime + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_notebook_runtime + ] = mock_object + + request = {} + await client.delete_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_notebook_runtime_async( transport: str = "grpc_asyncio", @@ -4176,6 +4946,9 @@ def test_upgrade_notebook_runtime_empty_call(): with mock.patch.object( type(client.transport.upgrade_notebook_runtime), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.upgrade_notebook_runtime() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4201,6 +4974,9 @@ def test_upgrade_notebook_runtime_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.upgrade_notebook_runtime), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.upgrade_notebook_runtime(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4209,6 +4985,50 @@ def test_upgrade_notebook_runtime_non_empty_request_with_auto_populated_field(): ) +def test_upgrade_notebook_runtime_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.upgrade_notebook_runtime + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.upgrade_notebook_runtime + ] = mock_rpc + request = {} + client.upgrade_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.upgrade_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_upgrade_notebook_runtime_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4232,6 +5052,56 @@ async def test_upgrade_notebook_runtime_empty_call_async(): assert args[0] == notebook_service.UpgradeNotebookRuntimeRequest() +@pytest.mark.asyncio +async def test_upgrade_notebook_runtime_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.upgrade_notebook_runtime + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.upgrade_notebook_runtime + ] = mock_object + + request = {} + await client.upgrade_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.upgrade_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_upgrade_notebook_runtime_async( transport: str = "grpc_asyncio", @@ -4469,6 +5339,9 @@ def test_start_notebook_runtime_empty_call(): with mock.patch.object( type(client.transport.start_notebook_runtime), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.start_notebook_runtime() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4494,6 +5367,9 @@ def test_start_notebook_runtime_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.start_notebook_runtime), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.start_notebook_runtime(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4502,6 +5378,50 @@ def test_start_notebook_runtime_non_empty_request_with_auto_populated_field(): ) +def test_start_notebook_runtime_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.start_notebook_runtime + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.start_notebook_runtime + ] = mock_rpc + request = {} + client.start_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.start_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_start_notebook_runtime_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4525,6 +5445,56 @@ async def test_start_notebook_runtime_empty_call_async(): assert args[0] == notebook_service.StartNotebookRuntimeRequest() +@pytest.mark.asyncio +async def test_start_notebook_runtime_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.start_notebook_runtime + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.start_notebook_runtime + ] = mock_object + + request = {} + await client.start_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.start_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_start_notebook_runtime_async( transport: str = "grpc_asyncio", @@ -4718,125 +5688,2598 @@ async def test_start_notebook_runtime_flattened_error_async(): @pytest.mark.parametrize( "request_type", [ - notebook_service.CreateNotebookRuntimeTemplateRequest, + notebook_service.GetNotebookExecutionJobRequest, dict, ], ) -def test_create_notebook_runtime_template_rest(request_type): +def test_get_notebook_execution_job(request_type, transport: str = "grpc"): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport="rest", + transport=transport, ) - # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} - request_init["notebook_runtime_template"] = { - "name": "name_value", - "display_name": "display_name_value", - "description": "description_value", - "is_default": True, - "machine_spec": { - "machine_type": "machine_type_value", - "accelerator_type": 1, - "accelerator_count": 1805, - "tpu_topology": "tpu_topology_value", - }, - "data_persistent_disk_spec": { - "disk_type": "disk_type_value", - "disk_size_gb": 1261, - }, - "network_spec": { - "enable_internet_access": True, - "network": "network_value", - "subnetwork": "subnetwork_value", - }, - "service_account": "service_account_value", - "etag": "etag_value", - "labels": {}, - "idle_shutdown_config": { - "idle_timeout": {"seconds": 751, "nanos": 543}, - "idle_shutdown_disabled": True, - }, - "euc_config": {"euc_disabled": True, "bypass_actas_check": True}, - "create_time": {"seconds": 751, "nanos": 543}, - "update_time": {}, - 
"notebook_runtime_type": 1, - "shielded_vm_config": {"enable_secure_boot": True}, - "network_tags": ["network_tags_value1", "network_tags_value2"], - } - # The version of a generated dependency at test runtime may differ from the version used during generation. - # Delete any fields which are not present in the current runtime dependency - # See https://github.com/googleapis/gapic-generator-python/issues/1748 - - # Determine if the message type is proto-plus or protobuf - test_field = notebook_service.CreateNotebookRuntimeTemplateRequest.meta.fields[ - "notebook_runtime_template" - ] + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() - def get_message_fields(field): - # Given a field which is a message (composite type), return a list with - # all the fields of the message. - # If the field is not a composite type, return an empty list. - message_fields = [] + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_notebook_execution_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = notebook_execution_job.NotebookExecutionJob( + name="name_value", + display_name="display_name_value", + schedule_resource_name="schedule_resource_name_value", + job_state=job_state.JobState.JOB_STATE_QUEUED, + notebook_runtime_template_resource_name="notebook_runtime_template_resource_name_value", + gcs_output_uri="gcs_output_uri_value", + execution_user="execution_user_value", + ) + response = client.get_notebook_execution_job(request) - if hasattr(field, "message") and field.message: - is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = notebook_service.GetNotebookExecutionJobRequest() + assert args[0] == request - if is_field_type_proto_plus_type: - message_fields = field.message.meta.fields.values() - # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types - else: # pragma: NO COVER - message_fields = field.message.DESCRIPTOR.fields - return message_fields + # Establish that the response is the type that we expect. + assert isinstance(response, notebook_execution_job.NotebookExecutionJob) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.schedule_resource_name == "schedule_resource_name_value" + assert response.job_state == job_state.JobState.JOB_STATE_QUEUED - runtime_nested_fields = [ - (field.name, nested_field.name) - for field in get_message_fields(test_field) - for nested_field in get_message_fields(field) - ] - subfields_not_in_runtime = [] +def test_get_notebook_execution_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_notebook_execution_job), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.get_notebook_execution_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == notebook_service.GetNotebookExecutionJobRequest() + + +def test_get_notebook_execution_job_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. 
+ client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = notebook_service.GetNotebookExecutionJobRequest( + name="name_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_notebook_execution_job), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.get_notebook_execution_job(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == notebook_service.GetNotebookExecutionJobRequest( + name="name_value", + ) + + +def test_get_notebook_execution_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_notebook_execution_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_notebook_execution_job + ] = mock_rpc + request = {} + client.get_notebook_execution_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_notebook_execution_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_notebook_execution_job_empty_call_async(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_notebook_execution_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + notebook_execution_job.NotebookExecutionJob( + name="name_value", + display_name="display_name_value", + schedule_resource_name="schedule_resource_name_value", + job_state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) + response = await client.get_notebook_execution_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == notebook_service.GetNotebookExecutionJobRequest() + + +@pytest.mark.asyncio +async def test_get_notebook_execution_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_notebook_execution_job + in client._client._transport._wrapped_methods + ) + + # 
Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_notebook_execution_job + ] = mock_object + + request = {} + await client.get_notebook_execution_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.get_notebook_execution_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_notebook_execution_job_async( + transport: str = "grpc_asyncio", + request_type=notebook_service.GetNotebookExecutionJobRequest, +): + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_notebook_execution_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + notebook_execution_job.NotebookExecutionJob( + name="name_value", + display_name="display_name_value", + schedule_resource_name="schedule_resource_name_value", + job_state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) + response = await client.get_notebook_execution_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = notebook_service.GetNotebookExecutionJobRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, notebook_execution_job.NotebookExecutionJob) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.schedule_resource_name == "schedule_resource_name_value" + assert response.job_state == job_state.JobState.JOB_STATE_QUEUED + + +@pytest.mark.asyncio +async def test_get_notebook_execution_job_async_from_dict(): + await test_get_notebook_execution_job_async(request_type=dict) + + +def test_get_notebook_execution_job_field_headers(): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = notebook_service.GetNotebookExecutionJobRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_notebook_execution_job), "__call__" + ) as call: + call.return_value = notebook_execution_job.NotebookExecutionJob() + client.get_notebook_execution_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_notebook_execution_job_field_headers_async(): + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = notebook_service.GetNotebookExecutionJobRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.get_notebook_execution_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + notebook_execution_job.NotebookExecutionJob() + ) + await client.get_notebook_execution_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_get_notebook_execution_job_flattened(): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_notebook_execution_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = notebook_execution_job.NotebookExecutionJob() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_notebook_execution_job( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_get_notebook_execution_job_flattened_error(): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.get_notebook_execution_job( + notebook_service.GetNotebookExecutionJobRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_notebook_execution_job_flattened_async(): + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_notebook_execution_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = notebook_execution_job.NotebookExecutionJob() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + notebook_execution_job.NotebookExecutionJob() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_notebook_execution_job( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_get_notebook_execution_job_flattened_error_async(): + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + await client.get_notebook_execution_job( + notebook_service.GetNotebookExecutionJobRequest(), + name="name_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + notebook_service.ListNotebookExecutionJobsRequest, + dict, + ], +) +def test_list_notebook_execution_jobs(request_type, transport: str = "grpc"): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_notebook_execution_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = notebook_service.ListNotebookExecutionJobsResponse( + next_page_token="next_page_token_value", + ) + response = client.list_notebook_execution_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = notebook_service.ListNotebookExecutionJobsRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListNotebookExecutionJobsPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_notebook_execution_jobs_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.list_notebook_execution_jobs), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.list_notebook_execution_jobs() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == notebook_service.ListNotebookExecutionJobsRequest() + + +def test_list_notebook_execution_jobs_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = notebook_service.ListNotebookExecutionJobsRequest( + parent="parent_value", + filter="filter_value", + page_token="page_token_value", + order_by="order_by_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_notebook_execution_jobs), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client.list_notebook_execution_jobs(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == notebook_service.ListNotebookExecutionJobsRequest( + parent="parent_value", + filter="filter_value", + page_token="page_token_value", + order_by="order_by_value", + ) + + +def test_list_notebook_execution_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_notebook_execution_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_notebook_execution_jobs + ] = mock_rpc + request = {} + client.list_notebook_execution_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_notebook_execution_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_notebook_execution_jobs_empty_call_async(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.list_notebook_execution_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + notebook_service.ListNotebookExecutionJobsResponse( + next_page_token="next_page_token_value", + ) + ) + response = await client.list_notebook_execution_jobs() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == notebook_service.ListNotebookExecutionJobsRequest() + + +@pytest.mark.asyncio +async def test_list_notebook_execution_jobs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_notebook_execution_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_notebook_execution_jobs + ] = mock_object + + request = {} + await client.list_notebook_execution_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_notebook_execution_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_notebook_execution_jobs_async( + transport: str = "grpc_asyncio", + request_type=notebook_service.ListNotebookExecutionJobsRequest, +): + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_notebook_execution_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + notebook_service.ListNotebookExecutionJobsResponse( + next_page_token="next_page_token_value", + ) + ) + response = await client.list_notebook_execution_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = notebook_service.ListNotebookExecutionJobsRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListNotebookExecutionJobsAsyncPager) + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_notebook_execution_jobs_async_from_dict(): + await test_list_notebook_execution_jobs_async(request_type=dict) + + +def test_list_notebook_execution_jobs_field_headers(): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. 
Set these to a non-empty value. + request = notebook_service.ListNotebookExecutionJobsRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_notebook_execution_jobs), "__call__" + ) as call: + call.return_value = notebook_service.ListNotebookExecutionJobsResponse() + client.list_notebook_execution_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_notebook_execution_jobs_field_headers_async(): + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = notebook_service.ListNotebookExecutionJobsRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_notebook_execution_jobs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + notebook_service.ListNotebookExecutionJobsResponse() + ) + await client.list_notebook_execution_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +def test_list_notebook_execution_jobs_flattened(): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_notebook_execution_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = notebook_service.ListNotebookExecutionJobsResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_notebook_execution_jobs( + parent="parent_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + + +def test_list_notebook_execution_jobs_flattened_error(): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_notebook_execution_jobs( + notebook_service.ListNotebookExecutionJobsRequest(), + parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_notebook_execution_jobs_flattened_async(): + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_notebook_execution_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
+ call.return_value = notebook_service.ListNotebookExecutionJobsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + notebook_service.ListNotebookExecutionJobsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_notebook_execution_jobs( + parent="parent_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_list_notebook_execution_jobs_flattened_error_async(): + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_notebook_execution_jobs( + notebook_service.ListNotebookExecutionJobsRequest(), + parent="parent_value", + ) + + +def test_list_notebook_execution_jobs_pager(transport_name: str = "grpc"): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_notebook_execution_jobs), "__call__" + ) as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + notebook_service.ListNotebookExecutionJobsResponse( + notebook_execution_jobs=[ + notebook_execution_job.NotebookExecutionJob(), + notebook_execution_job.NotebookExecutionJob(), + notebook_execution_job.NotebookExecutionJob(), + ], + next_page_token="abc", + ), + notebook_service.ListNotebookExecutionJobsResponse( + notebook_execution_jobs=[], + next_page_token="def", + ), + notebook_service.ListNotebookExecutionJobsResponse( + notebook_execution_jobs=[ + notebook_execution_job.NotebookExecutionJob(), + ], + next_page_token="ghi", + ), + notebook_service.ListNotebookExecutionJobsResponse( + notebook_execution_jobs=[ + notebook_execution_job.NotebookExecutionJob(), + notebook_execution_job.NotebookExecutionJob(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_notebook_execution_jobs(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + assert all( + isinstance(i, notebook_execution_job.NotebookExecutionJob) for i in results + ) + + +def test_list_notebook_execution_jobs_pages(transport_name: str = "grpc"): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_notebook_execution_jobs), "__call__" + ) as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + notebook_service.ListNotebookExecutionJobsResponse( + notebook_execution_jobs=[ + notebook_execution_job.NotebookExecutionJob(), + notebook_execution_job.NotebookExecutionJob(), + notebook_execution_job.NotebookExecutionJob(), + ], + next_page_token="abc", + ), + notebook_service.ListNotebookExecutionJobsResponse( + notebook_execution_jobs=[], + next_page_token="def", + ), + notebook_service.ListNotebookExecutionJobsResponse( + notebook_execution_jobs=[ + notebook_execution_job.NotebookExecutionJob(), + ], + next_page_token="ghi", + ), + notebook_service.ListNotebookExecutionJobsResponse( + notebook_execution_jobs=[ + notebook_execution_job.NotebookExecutionJob(), + notebook_execution_job.NotebookExecutionJob(), + ], + ), + RuntimeError, + ) + pages = list(client.list_notebook_execution_jobs(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_notebook_execution_jobs_async_pager(): + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_notebook_execution_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + notebook_service.ListNotebookExecutionJobsResponse( + notebook_execution_jobs=[ + notebook_execution_job.NotebookExecutionJob(), + notebook_execution_job.NotebookExecutionJob(), + notebook_execution_job.NotebookExecutionJob(), + ], + next_page_token="abc", + ), + notebook_service.ListNotebookExecutionJobsResponse( + notebook_execution_jobs=[], + next_page_token="def", + ), + notebook_service.ListNotebookExecutionJobsResponse( + notebook_execution_jobs=[ + notebook_execution_job.NotebookExecutionJob(), + ], + next_page_token="ghi", + ), + notebook_service.ListNotebookExecutionJobsResponse( + notebook_execution_jobs=[ + notebook_execution_job.NotebookExecutionJob(), + notebook_execution_job.NotebookExecutionJob(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_notebook_execution_jobs( + request={}, + ) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: # pragma: no branch + responses.append(response) + + assert len(responses) == 6 + assert all( + isinstance(i, notebook_execution_job.NotebookExecutionJob) + for i in responses + ) + + +@pytest.mark.asyncio +async def test_list_notebook_execution_jobs_async_pages(): + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_notebook_execution_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + notebook_service.ListNotebookExecutionJobsResponse( + notebook_execution_jobs=[ + notebook_execution_job.NotebookExecutionJob(), + notebook_execution_job.NotebookExecutionJob(), + notebook_execution_job.NotebookExecutionJob(), + ], + next_page_token="abc", + ), + notebook_service.ListNotebookExecutionJobsResponse( + notebook_execution_jobs=[], + next_page_token="def", + ), + notebook_service.ListNotebookExecutionJobsResponse( + notebook_execution_jobs=[ + notebook_execution_job.NotebookExecutionJob(), + ], + next_page_token="ghi", + ), + notebook_service.ListNotebookExecutionJobsResponse( + notebook_execution_jobs=[ + notebook_execution_job.NotebookExecutionJob(), + notebook_execution_job.NotebookExecutionJob(), + ], + ), + RuntimeError, + ) + pages = [] + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch + await client.list_notebook_execution_jobs(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", + [ + notebook_service.DeleteNotebookExecutionJobRequest, + dict, + ], +) +def test_delete_notebook_execution_job(request_type, transport: str = "grpc"): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_notebook_execution_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
+ call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.delete_notebook_execution_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = notebook_service.DeleteNotebookExecutionJobRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_notebook_execution_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_notebook_execution_job), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.delete_notebook_execution_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == notebook_service.DeleteNotebookExecutionJobRequest() + + +def test_delete_notebook_execution_job_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = notebook_service.DeleteNotebookExecutionJobRequest( + name="name_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.delete_notebook_execution_job), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.delete_notebook_execution_job(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == notebook_service.DeleteNotebookExecutionJobRequest( + name="name_value", + ) + + +def test_delete_notebook_execution_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_notebook_execution_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_notebook_execution_job + ] = mock_rpc + request = {} + client.delete_notebook_execution_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_notebook_execution_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_notebook_execution_job_empty_call_async(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_notebook_execution_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.delete_notebook_execution_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == notebook_service.DeleteNotebookExecutionJobRequest() + + +@pytest.mark.asyncio +async def test_delete_notebook_execution_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_notebook_execution_job + in client._client._transport._wrapped_methods + ) + + # 
Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_notebook_execution_job + ] = mock_object + + request = {} + await client.delete_notebook_execution_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_notebook_execution_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_notebook_execution_job_async( + transport: str = "grpc_asyncio", + request_type=notebook_service.DeleteNotebookExecutionJobRequest, +): + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_notebook_execution_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.delete_notebook_execution_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = notebook_service.DeleteNotebookExecutionJobRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_notebook_execution_job_async_from_dict(): + await test_delete_notebook_execution_job_async(request_type=dict) + + +def test_delete_notebook_execution_job_field_headers(): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = notebook_service.DeleteNotebookExecutionJobRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_notebook_execution_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.delete_notebook_execution_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_notebook_execution_job_field_headers_async(): + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = notebook_service.DeleteNotebookExecutionJobRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.delete_notebook_execution_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + await client.delete_notebook_execution_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_delete_notebook_execution_job_flattened(): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_notebook_execution_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_notebook_execution_job( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_delete_notebook_execution_job_flattened_error(): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.delete_notebook_execution_job( + notebook_service.DeleteNotebookExecutionJobRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_notebook_execution_job_flattened_async(): + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_notebook_execution_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_notebook_execution_job( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_delete_notebook_execution_job_flattened_error_async(): + client = NotebookServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + await client.delete_notebook_execution_job( + notebook_service.DeleteNotebookExecutionJobRequest(), + name="name_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + notebook_service.CreateNotebookRuntimeTemplateRequest, + dict, + ], +) +def test_create_notebook_runtime_template_rest(request_type): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request_init["notebook_runtime_template"] = { + "name": "name_value", + "display_name": "display_name_value", + "description": "description_value", + "is_default": True, + "machine_spec": { + "machine_type": "machine_type_value", + "accelerator_type": 1, + "accelerator_count": 1805, + "tpu_topology": "tpu_topology_value", + }, + "data_persistent_disk_spec": { + "disk_type": "disk_type_value", + "disk_size_gb": 1261, + }, + "network_spec": { + "enable_internet_access": True, + "network": "network_value", + "subnetwork": "subnetwork_value", + }, + "service_account": "service_account_value", + "etag": "etag_value", + "labels": {}, + "idle_shutdown_config": { + "idle_timeout": {"seconds": 751, "nanos": 543}, + "idle_shutdown_disabled": True, + }, + "euc_config": {"euc_disabled": True, "bypass_actas_check": True}, + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "notebook_runtime_type": 1, + "shielded_vm_config": {"enable_secure_boot": True}, + "network_tags": ["network_tags_value1", "network_tags_value2"], + } + # The version of a generated dependency at test runtime may differ from the version used during generation. 
+ # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = notebook_service.CreateNotebookRuntimeTemplateRequest.meta.fields[ + "notebook_runtime_template" + ] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init[ + "notebook_runtime_template" + ].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request 
which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range( + 0, len(request_init["notebook_runtime_template"][field]) + ): + del request_init["notebook_runtime_template"][field][i][subfield] + else: + del request_init["notebook_runtime_template"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.create_notebook_runtime_template(request) + + # Establish that the response is the type that we expect. 
+ assert response.operation.name == "operations/spam" + + +def test_create_notebook_runtime_template_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_notebook_runtime_template + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_notebook_runtime_template + ] = mock_rpc + + request = {} + client.create_notebook_runtime_template(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_notebook_runtime_template(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_create_notebook_runtime_template_rest_required_fields( + request_type=notebook_service.CreateNotebookRuntimeTemplateRequest, +): + transport_class = transports.NotebookServiceRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_notebook_runtime_template._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_notebook_runtime_template._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("notebook_runtime_template_id",)) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. 
+ return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.create_notebook_runtime_template(request) + + expected_params = [] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_create_notebook_runtime_template_rest_unset_required_fields(): + transport = transports.NotebookServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = ( + transport.create_notebook_runtime_template._get_unset_required_fields({}) + ) + assert set(unset_fields) == ( + set(("notebookRuntimeTemplateId",)) + & set( + ( + "parent", + "notebookRuntimeTemplate", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_create_notebook_runtime_template_rest_interceptors(null_interceptor): + transport = transports.NotebookServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.NotebookServiceRestInterceptor(), + ) + client = 
NotebookServiceClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.NotebookServiceRestInterceptor, + "post_create_notebook_runtime_template", + ) as post, mock.patch.object( + transports.NotebookServiceRestInterceptor, + "pre_create_notebook_runtime_template", + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = notebook_service.CreateNotebookRuntimeTemplateRequest.pb( + notebook_service.CreateNotebookRuntimeTemplateRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() + ) + + request = notebook_service.CreateNotebookRuntimeTemplateRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + + client.create_notebook_runtime_template( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_create_notebook_runtime_template_rest_bad_request( + transport: str = "rest", + request_type=notebook_service.CreateNotebookRuntimeTemplateRequest, +): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.create_notebook_runtime_template(request) + + +def test_create_notebook_runtime_template_rest_flattened(): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/locations/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + notebook_runtime_template=notebook_runtime.NotebookRuntimeTemplate( + name="name_value" + ), + notebook_runtime_template_id="notebook_runtime_template_id_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.create_notebook_runtime_template(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1beta1/{parent=projects/*/locations/*}/notebookRuntimeTemplates" + % client.transport._host, + args[1], + ) + + +def test_create_notebook_runtime_template_rest_flattened_error(transport: str = "rest"): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_notebook_runtime_template( + notebook_service.CreateNotebookRuntimeTemplateRequest(), + parent="parent_value", + notebook_runtime_template=notebook_runtime.NotebookRuntimeTemplate( + name="name_value" + ), + notebook_runtime_template_id="notebook_runtime_template_id_value", + ) + + +def test_create_notebook_runtime_template_rest_error(): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + notebook_service.GetNotebookRuntimeTemplateRequest, + dict, + ], +) +def test_get_notebook_runtime_template_rest(request_type): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = { + "name": "projects/sample1/locations/sample2/notebookRuntimeTemplates/sample3" + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
+ return_value = notebook_runtime.NotebookRuntimeTemplate( + name="name_value", + display_name="display_name_value", + description="description_value", + is_default=True, + service_account="service_account_value", + etag="etag_value", + notebook_runtime_type=notebook_runtime.NotebookRuntimeType.USER_DEFINED, + network_tags=["network_tags_value"], + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = notebook_runtime.NotebookRuntimeTemplate.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.get_notebook_runtime_template(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, notebook_runtime.NotebookRuntimeTemplate) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" + assert response.is_default is True + assert response.service_account == "service_account_value" + assert response.etag == "etag_value" + assert ( + response.notebook_runtime_type + == notebook_runtime.NotebookRuntimeType.USER_DEFINED + ) + assert response.network_tags == ["network_tags_value"] + + +def test_get_notebook_runtime_template_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_notebook_runtime_template + in 
client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_notebook_runtime_template + ] = mock_rpc + + request = {} + client.get_notebook_runtime_template(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_notebook_runtime_template(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_get_notebook_runtime_template_rest_required_fields( + request_type=notebook_service.GetNotebookRuntimeTemplateRequest, +): + transport_class = transports.NotebookServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_notebook_runtime_template._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_notebook_runtime_template._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = 
request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = notebook_runtime.NotebookRuntimeTemplate() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = notebook_runtime.NotebookRuntimeTemplate.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.get_notebook_runtime_template(request) + + expected_params = [] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_get_notebook_runtime_template_rest_unset_required_fields(): + transport = transports.NotebookServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.get_notebook_runtime_template._get_unset_required_fields( + {} + ) + assert set(unset_fields) == (set(()) & set(("name",))) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_get_notebook_runtime_template_rest_interceptors(null_interceptor): + transport = transports.NotebookServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else 
transports.NotebookServiceRestInterceptor(), + ) + client = NotebookServiceClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.NotebookServiceRestInterceptor, "post_get_notebook_runtime_template" + ) as post, mock.patch.object( + transports.NotebookServiceRestInterceptor, "pre_get_notebook_runtime_template" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = notebook_service.GetNotebookRuntimeTemplateRequest.pb( + notebook_service.GetNotebookRuntimeTemplateRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = notebook_runtime.NotebookRuntimeTemplate.to_json( + notebook_runtime.NotebookRuntimeTemplate() + ) + + request = notebook_service.GetNotebookRuntimeTemplateRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = notebook_runtime.NotebookRuntimeTemplate() + + client.get_notebook_runtime_template( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_notebook_runtime_template_rest_bad_request( + transport: str = "rest", + request_type=notebook_service.GetNotebookRuntimeTemplateRequest, +): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = { + "name": "projects/sample1/locations/sample2/notebookRuntimeTemplates/sample3" + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.get_notebook_runtime_template(request) + + +def test_get_notebook_runtime_template_rest_flattened(): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = notebook_runtime.NotebookRuntimeTemplate() + + # get arguments that satisfy an http rule for this method + sample_request = { + "name": "projects/sample1/locations/sample2/notebookRuntimeTemplates/sample3" + } + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = notebook_runtime.NotebookRuntimeTemplate.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.get_notebook_runtime_template(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1beta1/{name=projects/*/locations/*/notebookRuntimeTemplates/*}" + % client.transport._host, + args[1], + ) + + +def test_get_notebook_runtime_template_rest_flattened_error(transport: str = "rest"): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_notebook_runtime_template( + notebook_service.GetNotebookRuntimeTemplateRequest(), + name="name_value", + ) + + +def test_get_notebook_runtime_template_rest_error(): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + notebook_service.ListNotebookRuntimeTemplatesRequest, + dict, + ], +) +def test_list_notebook_runtime_templates_rest(request_type): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
+ return_value = notebook_service.ListNotebookRuntimeTemplatesResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = notebook_service.ListNotebookRuntimeTemplatesResponse.pb( + return_value + ) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.list_notebook_runtime_templates(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListNotebookRuntimeTemplatesPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_notebook_runtime_templates_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_notebook_runtime_templates + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_notebook_runtime_templates + ] = mock_rpc + + request = {} + client.list_notebook_runtime_templates(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_notebook_runtime_templates(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_list_notebook_runtime_templates_rest_required_fields( + request_type=notebook_service.ListNotebookRuntimeTemplatesRequest, +): + transport_class = transports.NotebookServiceRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_notebook_runtime_templates._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_notebook_runtime_templates._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set( + ( + "filter", + "order_by", + "page_size", + "page_token", + "read_mask", + ) + ) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = notebook_service.ListNotebookRuntimeTemplatesResponse() + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = notebook_service.ListNotebookRuntimeTemplatesResponse.pb( + return_value + ) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.list_notebook_runtime_templates(request) + + expected_params = [] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_list_notebook_runtime_templates_rest_unset_required_fields(): + transport = transports.NotebookServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.list_notebook_runtime_templates._get_unset_required_fields( + {} + ) + assert set(unset_fields) == ( + set( + ( + "filter", + "orderBy", + "pageSize", + "pageToken", + "readMask", + ) + ) + & set(("parent",)) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_list_notebook_runtime_templates_rest_interceptors(null_interceptor): + transport = transports.NotebookServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.NotebookServiceRestInterceptor(), + ) + client = NotebookServiceClient(transport=transport) + with 
mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.NotebookServiceRestInterceptor, + "post_list_notebook_runtime_templates", + ) as post, mock.patch.object( + transports.NotebookServiceRestInterceptor, "pre_list_notebook_runtime_templates" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = notebook_service.ListNotebookRuntimeTemplatesRequest.pb( + notebook_service.ListNotebookRuntimeTemplatesRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = ( + notebook_service.ListNotebookRuntimeTemplatesResponse.to_json( + notebook_service.ListNotebookRuntimeTemplatesResponse() + ) + ) - # For each item in the sample request, create a list of sub fields which are not present at runtime - # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime - for field, value in request_init[ - "notebook_runtime_template" - ].items(): # pragma: NO COVER - result = None - is_repeated = False - # For repeated fields - if isinstance(value, list) and len(value): - is_repeated = True - result = value[0] - # For fields where the type is another message - if isinstance(value, dict): - result = value + request = notebook_service.ListNotebookRuntimeTemplatesRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = notebook_service.ListNotebookRuntimeTemplatesResponse() - if result and hasattr(result, "keys"): - for subfield in result.keys(): - if (field, subfield) not in runtime_nested_fields: - subfields_not_in_runtime.append( - { - "field": field, - "subfield": subfield, - "is_repeated": 
is_repeated, - } - ) + client.list_notebook_runtime_templates( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) - # Remove fields from the sample request which are not present in the runtime version of the dependency - # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime - for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER - field = subfield_to_delete.get("field") - field_repeated = subfield_to_delete.get("is_repeated") - subfield = subfield_to_delete.get("subfield") - if subfield: - if field_repeated: - for i in range( - 0, len(request_init["notebook_runtime_template"][field]) - ): - del request_init["notebook_runtime_template"][field][i][subfield] - else: - del request_init["notebook_runtime_template"][field][subfield] + pre.assert_called_once() + post.assert_called_once() + + +def test_list_notebook_runtime_templates_rest_bad_request( + transport: str = "rest", + request_type=notebook_service.ListNotebookRuntimeTemplatesRequest, +): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.list_notebook_runtime_templates(request) + + +def test_list_notebook_runtime_templates_rest_flattened(): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = notebook_service.ListNotebookRuntimeTemplatesResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/locations/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = notebook_service.ListNotebookRuntimeTemplatesResponse.pb( + return_value + ) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.list_notebook_runtime_templates(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1beta1/{parent=projects/*/locations/*}/notebookRuntimeTemplates" + % client.transport._host, + args[1], + ) + + +def test_list_notebook_runtime_templates_rest_flattened_error(transport: str = "rest"): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_notebook_runtime_templates( + notebook_service.ListNotebookRuntimeTemplatesRequest(), + parent="parent_value", + ) + + +def test_list_notebook_runtime_templates_rest_pager(transport: str = "rest"): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + notebook_service.ListNotebookRuntimeTemplatesResponse( + notebook_runtime_templates=[ + notebook_runtime.NotebookRuntimeTemplate(), + notebook_runtime.NotebookRuntimeTemplate(), + notebook_runtime.NotebookRuntimeTemplate(), + ], + next_page_token="abc", + ), + notebook_service.ListNotebookRuntimeTemplatesResponse( + notebook_runtime_templates=[], + next_page_token="def", + ), + notebook_service.ListNotebookRuntimeTemplatesResponse( + notebook_runtime_templates=[ + notebook_runtime.NotebookRuntimeTemplate(), + ], + next_page_token="ghi", + ), + notebook_service.ListNotebookRuntimeTemplatesResponse( + notebook_runtime_templates=[ + notebook_runtime.NotebookRuntimeTemplate(), + notebook_runtime.NotebookRuntimeTemplate(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + notebook_service.ListNotebookRuntimeTemplatesResponse.to_json(x) + for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {"parent": "projects/sample1/locations/sample2"} + + pager = client.list_notebook_runtime_templates(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all( + isinstance(i, notebook_runtime.NotebookRuntimeTemplate) for i in results + ) + + pages = list( + client.list_notebook_runtime_templates(request=sample_request).pages + ) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", 
+ [ + notebook_service.DeleteNotebookRuntimeTemplateRequest, + dict, + ], +) +def test_delete_notebook_runtime_template_rest(request_type): + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = { + "name": "projects/sample1/locations/sample2/notebookRuntimeTemplates/sample3" + } request = request_type(**request_init) # Mock the http request call within the method and fake a response. @@ -4844,26 +8287,71 @@ def get_message_fields(field): # Designate an appropriate value for the returned response. return_value = operations_pb2.Operation(name="operations/spam") - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 200 - json_return_value = json_format.MessageToJson(return_value) + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.delete_notebook_runtime_template(request) + + # Establish that the response is the type that we expect. 
+ assert response.operation.name == "operations/spam" + + +def test_delete_notebook_runtime_template_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_notebook_runtime_template + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_notebook_runtime_template + ] = mock_rpc + + request = {} + client.delete_notebook_runtime_template(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() - response_value._content = json_return_value.encode("UTF-8") - req.return_value = response_value - response = client.create_notebook_runtime_template(request) + client.delete_notebook_runtime_template(request) - # Establish that the response is the type that we expect. 
- assert response.operation.name == "operations/spam" + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 -def test_create_notebook_runtime_template_rest_required_fields( - request_type=notebook_service.CreateNotebookRuntimeTemplateRequest, +def test_delete_notebook_runtime_template_rest_required_fields( + request_type=notebook_service.DeleteNotebookRuntimeTemplateRequest, ): transport_class = transports.NotebookServiceRestTransport request_init = {} - request_init["parent"] = "" + request_init["name"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -4874,23 +8362,21 @@ def test_create_notebook_runtime_template_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).create_notebook_runtime_template._get_unset_required_fields(jsonified_request) + ).delete_notebook_runtime_template._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["parent"] = "parent_value" + jsonified_request["name"] = "name_value" unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).create_notebook_runtime_template._get_unset_required_fields(jsonified_request) - # Check that path parameters and body parameters are not mixing in. 
- assert not set(unset_fields) - set(("notebook_runtime_template_id",)) + ).delete_notebook_runtime_template._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "parent" in jsonified_request - assert jsonified_request["parent"] == "parent_value" + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -4911,10 +8397,9 @@ def test_create_notebook_runtime_template_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "post", + "method": "delete", "query_params": pb_request, } - transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() @@ -4924,34 +8409,26 @@ def test_create_notebook_runtime_template_rest_required_fields( response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.create_notebook_runtime_template(request) + response = client.delete_notebook_runtime_template(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_create_notebook_runtime_template_rest_unset_required_fields(): +def test_delete_notebook_runtime_template_rest_unset_required_fields(): transport = transports.NotebookServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) unset_fields = ( - transport.create_notebook_runtime_template._get_unset_required_fields({}) - ) - assert set(unset_fields) == ( - set(("notebookRuntimeTemplateId",)) - & set( - ( - "parent", - "notebookRuntimeTemplate", - ) - ) + transport.delete_notebook_runtime_template._get_unset_required_fields({}) ) + assert set(unset_fields) == (set(()) & set(("name",))) @pytest.mark.parametrize("null_interceptor", [True, False]) -def 
test_create_notebook_runtime_template_rest_interceptors(null_interceptor): +def test_delete_notebook_runtime_template_rest_interceptors(null_interceptor): transport = transports.NotebookServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -4967,15 +8444,15 @@ def test_create_notebook_runtime_template_rest_interceptors(null_interceptor): operation.Operation, "_set_result_from_operation" ), mock.patch.object( transports.NotebookServiceRestInterceptor, - "post_create_notebook_runtime_template", + "post_delete_notebook_runtime_template", ) as post, mock.patch.object( transports.NotebookServiceRestInterceptor, - "pre_create_notebook_runtime_template", + "pre_delete_notebook_runtime_template", ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = notebook_service.CreateNotebookRuntimeTemplateRequest.pb( - notebook_service.CreateNotebookRuntimeTemplateRequest() + pb_message = notebook_service.DeleteNotebookRuntimeTemplateRequest.pb( + notebook_service.DeleteNotebookRuntimeTemplateRequest() ) transcode.return_value = { "method": "post", @@ -4991,7 +8468,7 @@ def test_create_notebook_runtime_template_rest_interceptors(null_interceptor): operations_pb2.Operation() ) - request = notebook_service.CreateNotebookRuntimeTemplateRequest() + request = notebook_service.DeleteNotebookRuntimeTemplateRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), @@ -4999,7 +8476,7 @@ def test_create_notebook_runtime_template_rest_interceptors(null_interceptor): pre.return_value = request, metadata post.return_value = operations_pb2.Operation() - client.create_notebook_runtime_template( + client.delete_notebook_runtime_template( request, metadata=[ ("key", "val"), @@ -5011,9 +8488,9 @@ def test_create_notebook_runtime_template_rest_interceptors(null_interceptor): post.assert_called_once() -def test_create_notebook_runtime_template_rest_bad_request( +def test_delete_notebook_runtime_template_rest_bad_request( transport: str = 
"rest", - request_type=notebook_service.CreateNotebookRuntimeTemplateRequest, + request_type=notebook_service.DeleteNotebookRuntimeTemplateRequest, ): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -5021,7 +8498,9 @@ def test_create_notebook_runtime_template_rest_bad_request( ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "name": "projects/sample1/locations/sample2/notebookRuntimeTemplates/sample3" + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -5033,10 +8512,10 @@ def test_create_notebook_runtime_template_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.create_notebook_runtime_template(request) + client.delete_notebook_runtime_template(request) -def test_create_notebook_runtime_template_rest_flattened(): +def test_delete_notebook_runtime_template_rest_flattened(): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -5048,15 +8527,13 @@ def test_create_notebook_runtime_template_rest_flattened(): return_value = operations_pb2.Operation(name="operations/spam") # get arguments that satisfy an http rule for this method - sample_request = {"parent": "projects/sample1/locations/sample2"} + sample_request = { + "name": "projects/sample1/locations/sample2/notebookRuntimeTemplates/sample3" + } # get truthy value for each flattened field mock_args = dict( - parent="parent_value", - notebook_runtime_template=notebook_runtime.NotebookRuntimeTemplate( - name="name_value" - ), - notebook_runtime_template_id="notebook_runtime_template_id_value", + name="name_value", ) mock_args.update(sample_request) @@ -5067,20 +8544,20 @@ def test_create_notebook_runtime_template_rest_flattened(): response_value._content = json_return_value.encode("UTF-8") 
req.return_value = response_value - client.create_notebook_runtime_template(**mock_args) + client.delete_notebook_runtime_template(**mock_args) # Establish that the underlying call was made with the expected # request object values. assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{parent=projects/*/locations/*}/notebookRuntimeTemplates" + "%s/v1beta1/{name=projects/*/locations/*/notebookRuntimeTemplates/*}" % client.transport._host, args[1], ) -def test_create_notebook_runtime_template_rest_flattened_error(transport: str = "rest"): +def test_delete_notebook_runtime_template_rest_flattened_error(transport: str = "rest"): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -5089,17 +8566,13 @@ def test_create_notebook_runtime_template_rest_flattened_error(transport: str = # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.create_notebook_runtime_template( - notebook_service.CreateNotebookRuntimeTemplateRequest(), - parent="parent_value", - notebook_runtime_template=notebook_runtime.NotebookRuntimeTemplate( - name="name_value" - ), - notebook_runtime_template_id="notebook_runtime_template_id_value", + client.delete_notebook_runtime_template( + notebook_service.DeleteNotebookRuntimeTemplateRequest(), + name="name_value", ) -def test_create_notebook_runtime_template_rest_error(): +def test_delete_notebook_runtime_template_rest_error(): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -5108,69 +8581,91 @@ def test_create_notebook_runtime_template_rest_error(): @pytest.mark.parametrize( "request_type", [ - notebook_service.GetNotebookRuntimeTemplateRequest, + notebook_service.AssignNotebookRuntimeRequest, dict, ], ) -def test_get_notebook_runtime_template_rest(request_type): +def 
test_assign_notebook_runtime_rest(request_type): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", ) # send a request that will satisfy transcoding - request_init = { - "name": "projects/sample1/locations/sample2/notebookRuntimeTemplates/sample3" - } + request_init = {"parent": "projects/sample1/locations/sample2"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = notebook_runtime.NotebookRuntimeTemplate( - name="name_value", - display_name="display_name_value", - description="description_value", - is_default=True, - service_account="service_account_value", - etag="etag_value", - notebook_runtime_type=notebook_runtime.NotebookRuntimeType.USER_DEFINED, - network_tags=["network_tags_value"], - ) + return_value = operations_pb2.Operation(name="operations/spam") # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - # Convert return value to protobuf type - return_value = notebook_runtime.NotebookRuntimeTemplate.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.get_notebook_runtime_template(request) + response = client.assign_notebook_runtime(request) # Establish that the response is the type that we expect. 
- assert isinstance(response, notebook_runtime.NotebookRuntimeTemplate) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.description == "description_value" - assert response.is_default is True - assert response.service_account == "service_account_value" - assert response.etag == "etag_value" - assert ( - response.notebook_runtime_type - == notebook_runtime.NotebookRuntimeType.USER_DEFINED - ) - assert response.network_tags == ["network_tags_value"] + assert response.operation.name == "operations/spam" -def test_get_notebook_runtime_template_rest_required_fields( - request_type=notebook_service.GetNotebookRuntimeTemplateRequest, +def test_assign_notebook_runtime_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.assign_notebook_runtime + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.assign_notebook_runtime + ] = mock_rpc + + request = {} + client.assign_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.assign_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_assign_notebook_runtime_rest_required_fields( + request_type=notebook_service.AssignNotebookRuntimeRequest, ): transport_class = transports.NotebookServiceRestTransport request_init = {} - request_init["name"] = "" + request_init["parent"] = "" + request_init["notebook_runtime_template"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -5181,21 +8676,27 @@ def test_get_notebook_runtime_template_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).get_notebook_runtime_template._get_unset_required_fields(jsonified_request) + ).assign_notebook_runtime._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["name"] = "name_value" + jsonified_request["parent"] = "parent_value" + jsonified_request["notebookRuntimeTemplate"] = "notebook_runtime_template_value" unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).get_notebook_runtime_template._get_unset_required_fields(jsonified_request) + ).assign_notebook_runtime._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "name" in jsonified_request - assert jsonified_request["name"] == "name_value" + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + assert "notebookRuntimeTemplate" in jsonified_request + assert ( + 
jsonified_request["notebookRuntimeTemplate"] + == "notebook_runtime_template_value" + ) client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -5204,7 +8705,7 @@ def test_get_notebook_runtime_template_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = notebook_runtime.NotebookRuntimeTemplate() + return_value = operations_pb2.Operation(name="operations/spam") # Mock the http request call within the method and fake a response. with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -5216,41 +8717,46 @@ def test_get_notebook_runtime_template_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "get", + "method": "post", "query_params": pb_request, } + transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 - - # Convert return value to protobuf type - return_value = notebook_runtime.NotebookRuntimeTemplate.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.get_notebook_runtime_template(request) + response = client.assign_notebook_runtime(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_get_notebook_runtime_template_rest_unset_required_fields(): +def test_assign_notebook_runtime_rest_unset_required_fields(): transport = transports.NotebookServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.get_notebook_runtime_template._get_unset_required_fields( - {} + unset_fields = transport.assign_notebook_runtime._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + 
( + "parent", + "notebookRuntimeTemplate", + "notebookRuntime", + ) + ) ) - assert set(unset_fields) == (set(()) & set(("name",))) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_get_notebook_runtime_template_rest_interceptors(null_interceptor): +def test_assign_notebook_runtime_rest_interceptors(null_interceptor): transport = transports.NotebookServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -5263,14 +8769,16 @@ def test_get_notebook_runtime_template_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - transports.NotebookServiceRestInterceptor, "post_get_notebook_runtime_template" + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.NotebookServiceRestInterceptor, "post_assign_notebook_runtime" ) as post, mock.patch.object( - transports.NotebookServiceRestInterceptor, "pre_get_notebook_runtime_template" + transports.NotebookServiceRestInterceptor, "pre_assign_notebook_runtime" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = notebook_service.GetNotebookRuntimeTemplateRequest.pb( - notebook_service.GetNotebookRuntimeTemplateRequest() + pb_message = notebook_service.AssignNotebookRuntimeRequest.pb( + notebook_service.AssignNotebookRuntimeRequest() ) transcode.return_value = { "method": "post", @@ -5282,19 +8790,19 @@ def test_get_notebook_runtime_template_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = notebook_runtime.NotebookRuntimeTemplate.to_json( - notebook_runtime.NotebookRuntimeTemplate() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() ) - request = notebook_service.GetNotebookRuntimeTemplateRequest() + request = notebook_service.AssignNotebookRuntimeRequest() metadata = [ ("key", "val"), 
("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = notebook_runtime.NotebookRuntimeTemplate() + post.return_value = operations_pb2.Operation() - client.get_notebook_runtime_template( + client.assign_notebook_runtime( request, metadata=[ ("key", "val"), @@ -5306,9 +8814,8 @@ def test_get_notebook_runtime_template_rest_interceptors(null_interceptor): post.assert_called_once() -def test_get_notebook_runtime_template_rest_bad_request( - transport: str = "rest", - request_type=notebook_service.GetNotebookRuntimeTemplateRequest, +def test_assign_notebook_runtime_rest_bad_request( + transport: str = "rest", request_type=notebook_service.AssignNotebookRuntimeRequest ): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -5316,9 +8823,7 @@ def test_get_notebook_runtime_template_rest_bad_request( ) # send a request that will satisfy transcoding - request_init = { - "name": "projects/sample1/locations/sample2/notebookRuntimeTemplates/sample3" - } + request_init = {"parent": "projects/sample1/locations/sample2"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -5330,10 +8835,10 @@ def test_get_notebook_runtime_template_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.get_notebook_runtime_template(request) + client.assign_notebook_runtime(request) -def test_get_notebook_runtime_template_rest_flattened(): +def test_assign_notebook_runtime_rest_flattened(): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -5342,42 +8847,41 @@ def test_get_notebook_runtime_template_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
- return_value = notebook_runtime.NotebookRuntimeTemplate() + return_value = operations_pb2.Operation(name="operations/spam") # get arguments that satisfy an http rule for this method - sample_request = { - "name": "projects/sample1/locations/sample2/notebookRuntimeTemplates/sample3" - } + sample_request = {"parent": "projects/sample1/locations/sample2"} # get truthy value for each flattened field mock_args = dict( - name="name_value", + parent="parent_value", + notebook_runtime_template="notebook_runtime_template_value", + notebook_runtime=gca_notebook_runtime.NotebookRuntime(name="name_value"), + notebook_runtime_id="notebook_runtime_id_value", ) mock_args.update(sample_request) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - # Convert return value to protobuf type - return_value = notebook_runtime.NotebookRuntimeTemplate.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - client.get_notebook_runtime_template(**mock_args) + client.assign_notebook_runtime(**mock_args) # Establish that the underlying call was made with the expected # request object values. assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{name=projects/*/locations/*/notebookRuntimeTemplates/*}" + "%s/v1beta1/{parent=projects/*/locations/*}/notebookRuntimes:assign" % client.transport._host, args[1], ) -def test_get_notebook_runtime_template_rest_flattened_error(transport: str = "rest"): +def test_assign_notebook_runtime_rest_flattened_error(transport: str = "rest"): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -5386,13 +8890,16 @@ def test_get_notebook_runtime_template_rest_flattened_error(transport: str = "re # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - client.get_notebook_runtime_template( - notebook_service.GetNotebookRuntimeTemplateRequest(), - name="name_value", + client.assign_notebook_runtime( + notebook_service.AssignNotebookRuntimeRequest(), + parent="parent_value", + notebook_runtime_template="notebook_runtime_template_value", + notebook_runtime=gca_notebook_runtime.NotebookRuntime(name="name_value"), + notebook_runtime_id="notebook_runtime_id_value", ) -def test_get_notebook_runtime_template_rest_error(): +def test_assign_notebook_runtime_rest_error(): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -5401,52 +8908,119 @@ def test_get_notebook_runtime_template_rest_error(): @pytest.mark.parametrize( "request_type", [ - notebook_service.ListNotebookRuntimeTemplatesRequest, + notebook_service.GetNotebookRuntimeRequest, dict, ], ) -def test_list_notebook_runtime_templates_rest(request_type): +def test_get_notebook_runtime_rest(request_type): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "name": "projects/sample1/locations/sample2/notebookRuntimes/sample3" + } request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
- return_value = notebook_service.ListNotebookRuntimeTemplatesResponse( - next_page_token="next_page_token_value", + return_value = notebook_runtime.NotebookRuntime( + name="name_value", + runtime_user="runtime_user_value", + proxy_uri="proxy_uri_value", + health_state=notebook_runtime.NotebookRuntime.HealthState.HEALTHY, + display_name="display_name_value", + description="description_value", + service_account="service_account_value", + runtime_state=notebook_runtime.NotebookRuntime.RuntimeState.RUNNING, + is_upgradable=True, + version="version_value", + notebook_runtime_type=notebook_runtime.NotebookRuntimeType.USER_DEFINED, + network_tags=["network_tags_value"], ) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 # Convert return value to protobuf type - return_value = notebook_service.ListNotebookRuntimeTemplatesResponse.pb( - return_value - ) + return_value = notebook_runtime.NotebookRuntime.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.list_notebook_runtime_templates(request) + response = client.get_notebook_runtime(request) # Establish that the response is the type that we expect. 
- assert isinstance(response, pagers.ListNotebookRuntimeTemplatesPager) - assert response.next_page_token == "next_page_token_value" + assert isinstance(response, notebook_runtime.NotebookRuntime) + assert response.name == "name_value" + assert response.runtime_user == "runtime_user_value" + assert response.proxy_uri == "proxy_uri_value" + assert response.health_state == notebook_runtime.NotebookRuntime.HealthState.HEALTHY + assert response.display_name == "display_name_value" + assert response.description == "description_value" + assert response.service_account == "service_account_value" + assert ( + response.runtime_state == notebook_runtime.NotebookRuntime.RuntimeState.RUNNING + ) + assert response.is_upgradable is True + assert response.version == "version_value" + assert ( + response.notebook_runtime_type + == notebook_runtime.NotebookRuntimeType.USER_DEFINED + ) + assert response.network_tags == ["network_tags_value"] -def test_list_notebook_runtime_templates_rest_required_fields( - request_type=notebook_service.ListNotebookRuntimeTemplatesRequest, +def test_get_notebook_runtime_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_notebook_runtime in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.get_notebook_runtime + ] = mock_rpc + + request = {} + client.get_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_get_notebook_runtime_rest_required_fields( + request_type=notebook_service.GetNotebookRuntimeRequest, ): transport_class = transports.NotebookServiceRestTransport request_init = {} - request_init["parent"] = "" + request_init["name"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -5457,31 +9031,21 @@ def test_list_notebook_runtime_templates_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).list_notebook_runtime_templates._get_unset_required_fields(jsonified_request) + ).get_notebook_runtime._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["parent"] = "parent_value" + jsonified_request["name"] = "name_value" unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).list_notebook_runtime_templates._get_unset_required_fields(jsonified_request) - # Check that path parameters and body parameters are not mixing in. 
- assert not set(unset_fields) - set( - ( - "filter", - "order_by", - "page_size", - "page_token", - "read_mask", - ) - ) + ).get_notebook_runtime._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "parent" in jsonified_request - assert jsonified_request["parent"] == "parent_value" + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -5490,7 +9054,7 @@ def test_list_notebook_runtime_templates_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = notebook_service.ListNotebookRuntimeTemplatesResponse() + return_value = notebook_runtime.NotebookRuntime() # Mock the http request call within the method and fake a response. with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -5511,45 +9075,30 @@ def test_list_notebook_runtime_templates_rest_required_fields( response_value.status_code = 200 # Convert return value to protobuf type - return_value = notebook_service.ListNotebookRuntimeTemplatesResponse.pb( - return_value - ) + return_value = notebook_runtime.NotebookRuntime.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.list_notebook_runtime_templates(request) + response = client.get_notebook_runtime(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_list_notebook_runtime_templates_rest_unset_required_fields(): +def test_get_notebook_runtime_rest_unset_required_fields(): transport = transports.NotebookServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = 
transport.list_notebook_runtime_templates._get_unset_required_fields( - {} - ) - assert set(unset_fields) == ( - set( - ( - "filter", - "orderBy", - "pageSize", - "pageToken", - "readMask", - ) - ) - & set(("parent",)) - ) + unset_fields = transport.get_notebook_runtime._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_list_notebook_runtime_templates_rest_interceptors(null_interceptor): +def test_get_notebook_runtime_rest_interceptors(null_interceptor): transport = transports.NotebookServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -5562,15 +9111,14 @@ def test_list_notebook_runtime_templates_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - transports.NotebookServiceRestInterceptor, - "post_list_notebook_runtime_templates", + transports.NotebookServiceRestInterceptor, "post_get_notebook_runtime" ) as post, mock.patch.object( - transports.NotebookServiceRestInterceptor, "pre_list_notebook_runtime_templates" + transports.NotebookServiceRestInterceptor, "pre_get_notebook_runtime" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = notebook_service.ListNotebookRuntimeTemplatesRequest.pb( - notebook_service.ListNotebookRuntimeTemplatesRequest() + pb_message = notebook_service.GetNotebookRuntimeRequest.pb( + notebook_service.GetNotebookRuntimeRequest() ) transcode.return_value = { "method": "post", @@ -5582,21 +9130,19 @@ def test_list_notebook_runtime_templates_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = ( - notebook_service.ListNotebookRuntimeTemplatesResponse.to_json( - notebook_service.ListNotebookRuntimeTemplatesResponse() - ) + req.return_value._content = 
notebook_runtime.NotebookRuntime.to_json( + notebook_runtime.NotebookRuntime() ) - request = notebook_service.ListNotebookRuntimeTemplatesRequest() + request = notebook_service.GetNotebookRuntimeRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = notebook_service.ListNotebookRuntimeTemplatesResponse() + post.return_value = notebook_runtime.NotebookRuntime() - client.list_notebook_runtime_templates( + client.get_notebook_runtime( request, metadata=[ ("key", "val"), @@ -5608,9 +9154,8 @@ def test_list_notebook_runtime_templates_rest_interceptors(null_interceptor): post.assert_called_once() -def test_list_notebook_runtime_templates_rest_bad_request( - transport: str = "rest", - request_type=notebook_service.ListNotebookRuntimeTemplatesRequest, +def test_get_notebook_runtime_rest_bad_request( + transport: str = "rest", request_type=notebook_service.GetNotebookRuntimeRequest ): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -5618,7 +9163,9 @@ def test_list_notebook_runtime_templates_rest_bad_request( ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "name": "projects/sample1/locations/sample2/notebookRuntimes/sample3" + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. 
@@ -5630,10 +9177,10 @@ def test_list_notebook_runtime_templates_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.list_notebook_runtime_templates(request) + client.get_notebook_runtime(request) -def test_list_notebook_runtime_templates_rest_flattened(): +def test_get_notebook_runtime_rest_flattened(): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -5642,14 +9189,16 @@ def test_list_notebook_runtime_templates_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = notebook_service.ListNotebookRuntimeTemplatesResponse() + return_value = notebook_runtime.NotebookRuntime() # get arguments that satisfy an http rule for this method - sample_request = {"parent": "projects/sample1/locations/sample2"} + sample_request = { + "name": "projects/sample1/locations/sample2/notebookRuntimes/sample3" + } # get truthy value for each flattened field mock_args = dict( - parent="parent_value", + name="name_value", ) mock_args.update(sample_request) @@ -5657,27 +9206,25 @@ def test_list_notebook_runtime_templates_rest_flattened(): response_value = Response() response_value.status_code = 200 # Convert return value to protobuf type - return_value = notebook_service.ListNotebookRuntimeTemplatesResponse.pb( - return_value - ) + return_value = notebook_runtime.NotebookRuntime.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - client.list_notebook_runtime_templates(**mock_args) + client.get_notebook_runtime(**mock_args) # Establish that the underlying call was made with the expected # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{parent=projects/*/locations/*}/notebookRuntimeTemplates" + "%s/v1beta1/{name=projects/*/locations/*/notebookRuntimes/*}" % client.transport._host, args[1], ) -def test_list_notebook_runtime_templates_rest_flattened_error(transport: str = "rest"): +def test_get_notebook_runtime_rest_flattened_error(transport: str = "rest"): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -5686,124 +9233,106 @@ def test_list_notebook_runtime_templates_rest_flattened_error(transport: str = " # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.list_notebook_runtime_templates( - notebook_service.ListNotebookRuntimeTemplatesRequest(), - parent="parent_value", - ) - - -def test_list_notebook_runtime_templates_rest_pager(transport: str = "rest"): - client = NotebookServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, "request") as req: - # TODO(kbandes): remove this mock unless there's a good reason for it. 
- # with mock.patch.object(path_template, 'transcode') as transcode: - # Set the response as a series of pages - response = ( - notebook_service.ListNotebookRuntimeTemplatesResponse( - notebook_runtime_templates=[ - notebook_runtime.NotebookRuntimeTemplate(), - notebook_runtime.NotebookRuntimeTemplate(), - notebook_runtime.NotebookRuntimeTemplate(), - ], - next_page_token="abc", - ), - notebook_service.ListNotebookRuntimeTemplatesResponse( - notebook_runtime_templates=[], - next_page_token="def", - ), - notebook_service.ListNotebookRuntimeTemplatesResponse( - notebook_runtime_templates=[ - notebook_runtime.NotebookRuntimeTemplate(), - ], - next_page_token="ghi", - ), - notebook_service.ListNotebookRuntimeTemplatesResponse( - notebook_runtime_templates=[ - notebook_runtime.NotebookRuntimeTemplate(), - notebook_runtime.NotebookRuntimeTemplate(), - ], - ), - ) - # Two responses for two calls - response = response + response - - # Wrap the values into proper Response objs - response = tuple( - notebook_service.ListNotebookRuntimeTemplatesResponse.to_json(x) - for x in response - ) - return_values = tuple(Response() for i in response) - for return_val, response_val in zip(return_values, response): - return_val._content = response_val.encode("UTF-8") - return_val.status_code = 200 - req.side_effect = return_values - - sample_request = {"parent": "projects/sample1/locations/sample2"} - - pager = client.list_notebook_runtime_templates(request=sample_request) - - results = list(pager) - assert len(results) == 6 - assert all( - isinstance(i, notebook_runtime.NotebookRuntimeTemplate) for i in results + client.get_notebook_runtime( + notebook_service.GetNotebookRuntimeRequest(), + name="name_value", ) - pages = list( - client.list_notebook_runtime_templates(request=sample_request).pages - ) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token + +def test_get_notebook_runtime_rest_error(): + client = 
NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) @pytest.mark.parametrize( "request_type", [ - notebook_service.DeleteNotebookRuntimeTemplateRequest, + notebook_service.ListNotebookRuntimesRequest, dict, ], ) -def test_delete_notebook_runtime_template_rest(request_type): +def test_list_notebook_runtimes_rest(request_type): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", ) # send a request that will satisfy transcoding - request_init = { - "name": "projects/sample1/locations/sample2/notebookRuntimeTemplates/sample3" - } + request_init = {"parent": "projects/sample1/locations/sample2"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + return_value = notebook_service.ListNotebookRuntimesResponse( + next_page_token="next_page_token_value", + ) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 + # Convert return value to protobuf type + return_value = notebook_service.ListNotebookRuntimesResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.delete_notebook_runtime_template(request) + response = client.list_notebook_runtimes(request) # Establish that the response is the type that we expect. 
- assert response.operation.name == "operations/spam" + assert isinstance(response, pagers.ListNotebookRuntimesPager) + assert response.next_page_token == "next_page_token_value" -def test_delete_notebook_runtime_template_rest_required_fields( - request_type=notebook_service.DeleteNotebookRuntimeTemplateRequest, +def test_list_notebook_runtimes_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_notebook_runtimes + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_notebook_runtimes + ] = mock_rpc + + request = {} + client.list_notebook_runtimes(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_notebook_runtimes(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_list_notebook_runtimes_rest_required_fields( + request_type=notebook_service.ListNotebookRuntimesRequest, ): transport_class = transports.NotebookServiceRestTransport request_init = {} - request_init["name"] = "" + request_init["parent"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -5814,21 +9343,31 @@ def test_delete_notebook_runtime_template_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).delete_notebook_runtime_template._get_unset_required_fields(jsonified_request) + ).list_notebook_runtimes._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["name"] = "name_value" + jsonified_request["parent"] = "parent_value" unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).delete_notebook_runtime_template._get_unset_required_fields(jsonified_request) + ).list_notebook_runtimes._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. 
+ assert not set(unset_fields) - set( + ( + "filter", + "order_by", + "page_size", + "page_token", + "read_mask", + ) + ) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "name" in jsonified_request - assert jsonified_request["name"] == "name_value" + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -5837,7 +9376,7 @@ def test_delete_notebook_runtime_template_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + return_value = notebook_service.ListNotebookRuntimesResponse() # Mock the http request call within the method and fake a response. with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -5849,38 +9388,52 @@ def test_delete_notebook_runtime_template_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "delete", + "method": "get", "query_params": pb_request, } transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = notebook_service.ListNotebookRuntimesResponse.pb( + return_value + ) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.delete_notebook_runtime_template(request) + response = client.list_notebook_runtimes(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_delete_notebook_runtime_template_rest_unset_required_fields(): +def test_list_notebook_runtimes_rest_unset_required_fields(): transport = 
transports.NotebookServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = ( - transport.delete_notebook_runtime_template._get_unset_required_fields({}) + unset_fields = transport.list_notebook_runtimes._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "filter", + "orderBy", + "pageSize", + "pageToken", + "readMask", + ) + ) + & set(("parent",)) ) - assert set(unset_fields) == (set(()) & set(("name",))) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_delete_notebook_runtime_template_rest_interceptors(null_interceptor): +def test_list_notebook_runtimes_rest_interceptors(null_interceptor): transport = transports.NotebookServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -5893,18 +9446,14 @@ def test_delete_notebook_runtime_template_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - operation.Operation, "_set_result_from_operation" - ), mock.patch.object( - transports.NotebookServiceRestInterceptor, - "post_delete_notebook_runtime_template", + transports.NotebookServiceRestInterceptor, "post_list_notebook_runtimes" ) as post, mock.patch.object( - transports.NotebookServiceRestInterceptor, - "pre_delete_notebook_runtime_template", + transports.NotebookServiceRestInterceptor, "pre_list_notebook_runtimes" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = notebook_service.DeleteNotebookRuntimeTemplateRequest.pb( - notebook_service.DeleteNotebookRuntimeTemplateRequest() + pb_message = notebook_service.ListNotebookRuntimesRequest.pb( + notebook_service.ListNotebookRuntimesRequest() ) transcode.return_value = { "method": "post", @@ -5916,19 +9465,21 @@ def test_delete_notebook_runtime_template_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - 
req.return_value._content = json_format.MessageToJson( - operations_pb2.Operation() + req.return_value._content = ( + notebook_service.ListNotebookRuntimesResponse.to_json( + notebook_service.ListNotebookRuntimesResponse() + ) ) - request = notebook_service.DeleteNotebookRuntimeTemplateRequest() + request = notebook_service.ListNotebookRuntimesRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = operations_pb2.Operation() + post.return_value = notebook_service.ListNotebookRuntimesResponse() - client.delete_notebook_runtime_template( + client.list_notebook_runtimes( request, metadata=[ ("key", "val"), @@ -5940,9 +9491,8 @@ def test_delete_notebook_runtime_template_rest_interceptors(null_interceptor): post.assert_called_once() -def test_delete_notebook_runtime_template_rest_bad_request( - transport: str = "rest", - request_type=notebook_service.DeleteNotebookRuntimeTemplateRequest, +def test_list_notebook_runtimes_rest_bad_request( + transport: str = "rest", request_type=notebook_service.ListNotebookRuntimesRequest ): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -5950,9 +9500,7 @@ def test_delete_notebook_runtime_template_rest_bad_request( ) # send a request that will satisfy transcoding - request_init = { - "name": "projects/sample1/locations/sample2/notebookRuntimeTemplates/sample3" - } + request_init = {"parent": "projects/sample1/locations/sample2"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. 
@@ -5964,10 +9512,10 @@ def test_delete_notebook_runtime_template_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.delete_notebook_runtime_template(request) + client.list_notebook_runtimes(request) -def test_delete_notebook_runtime_template_rest_flattened(): +def test_list_notebook_runtimes_rest_flattened(): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -5976,40 +9524,40 @@ def test_delete_notebook_runtime_template_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + return_value = notebook_service.ListNotebookRuntimesResponse() # get arguments that satisfy an http rule for this method - sample_request = { - "name": "projects/sample1/locations/sample2/notebookRuntimeTemplates/sample3" - } + sample_request = {"parent": "projects/sample1/locations/sample2"} # get truthy value for each flattened field mock_args = dict( - name="name_value", + parent="parent_value", ) mock_args.update(sample_request) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 + # Convert return value to protobuf type + return_value = notebook_service.ListNotebookRuntimesResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - client.delete_notebook_runtime_template(**mock_args) + client.list_notebook_runtimes(**mock_args) # Establish that the underlying call was made with the expected # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{name=projects/*/locations/*/notebookRuntimeTemplates/*}" + "%s/v1beta1/{parent=projects/*/locations/*}/notebookRuntimes" % client.transport._host, args[1], ) -def test_delete_notebook_runtime_template_rest_flattened_error(transport: str = "rest"): +def test_list_notebook_runtimes_rest_flattened_error(transport: str = "rest"): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -6018,33 +9566,92 @@ def test_delete_notebook_runtime_template_rest_flattened_error(transport: str = # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.delete_notebook_runtime_template( - notebook_service.DeleteNotebookRuntimeTemplateRequest(), - name="name_value", + client.list_notebook_runtimes( + notebook_service.ListNotebookRuntimesRequest(), + parent="parent_value", ) -def test_delete_notebook_runtime_template_rest_error(): +def test_list_notebook_runtimes_rest_pager(transport: str = "rest"): client = NotebookServiceClient( - credentials=ga_credentials.AnonymousCredentials(), transport="rest" + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, ) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. 
+ # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + notebook_service.ListNotebookRuntimesResponse( + notebook_runtimes=[ + notebook_runtime.NotebookRuntime(), + notebook_runtime.NotebookRuntime(), + notebook_runtime.NotebookRuntime(), + ], + next_page_token="abc", + ), + notebook_service.ListNotebookRuntimesResponse( + notebook_runtimes=[], + next_page_token="def", + ), + notebook_service.ListNotebookRuntimesResponse( + notebook_runtimes=[ + notebook_runtime.NotebookRuntime(), + ], + next_page_token="ghi", + ), + notebook_service.ListNotebookRuntimesResponse( + notebook_runtimes=[ + notebook_runtime.NotebookRuntime(), + notebook_runtime.NotebookRuntime(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + notebook_service.ListNotebookRuntimesResponse.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {"parent": "projects/sample1/locations/sample2"} + + pager = client.list_notebook_runtimes(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, notebook_runtime.NotebookRuntime) for i in results) + + pages = list(client.list_notebook_runtimes(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + @pytest.mark.parametrize( "request_type", [ - notebook_service.AssignNotebookRuntimeRequest, + notebook_service.DeleteNotebookRuntimeRequest, dict, ], ) -def test_assign_notebook_runtime_rest(request_type): +def test_delete_notebook_runtime_rest(request_type): client = NotebookServiceClient( 
credentials=ga_credentials.AnonymousCredentials(), transport="rest", ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "name": "projects/sample1/locations/sample2/notebookRuntimes/sample3" + } request = request_type(**request_init) # Mock the http request call within the method and fake a response. @@ -6059,20 +9666,64 @@ def test_assign_notebook_runtime_rest(request_type): response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.assign_notebook_runtime(request) + response = client.delete_notebook_runtime(request) # Establish that the response is the type that we expect. assert response.operation.name == "operations/spam" -def test_assign_notebook_runtime_rest_required_fields( - request_type=notebook_service.AssignNotebookRuntimeRequest, +def test_delete_notebook_runtime_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_notebook_runtime + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_notebook_runtime + ] = mock_rpc + + request = {} + client.delete_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_delete_notebook_runtime_rest_required_fields( + request_type=notebook_service.DeleteNotebookRuntimeRequest, ): transport_class = transports.NotebookServiceRestTransport request_init = {} - request_init["parent"] = "" - request_init["notebook_runtime_template"] = "" + request_init["name"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -6083,27 +9734,21 @@ def test_assign_notebook_runtime_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).assign_notebook_runtime._get_unset_required_fields(jsonified_request) + ).delete_notebook_runtime._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["parent"] = "parent_value" - jsonified_request["notebookRuntimeTemplate"] = "notebook_runtime_template_value" + jsonified_request["name"] = "name_value" unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).assign_notebook_runtime._get_unset_required_fields(jsonified_request) + ).delete_notebook_runtime._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "parent" in jsonified_request - assert jsonified_request["parent"] == "parent_value" - assert "notebookRuntimeTemplate" in jsonified_request - assert ( - jsonified_request["notebookRuntimeTemplate"] - == "notebook_runtime_template_value" - ) + assert "name" in jsonified_request + assert 
jsonified_request["name"] == "name_value" client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -6124,10 +9769,9 @@ def test_assign_notebook_runtime_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "post", + "method": "delete", "query_params": pb_request, } - transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() @@ -6137,33 +9781,24 @@ def test_assign_notebook_runtime_rest_required_fields( response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.assign_notebook_runtime(request) + response = client.delete_notebook_runtime(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_assign_notebook_runtime_rest_unset_required_fields(): +def test_delete_notebook_runtime_rest_unset_required_fields(): transport = transports.NotebookServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.assign_notebook_runtime._get_unset_required_fields({}) - assert set(unset_fields) == ( - set(()) - & set( - ( - "parent", - "notebookRuntimeTemplate", - "notebookRuntime", - ) - ) - ) + unset_fields = transport.delete_notebook_runtime._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_assign_notebook_runtime_rest_interceptors(null_interceptor): +def test_delete_notebook_runtime_rest_interceptors(null_interceptor): transport = transports.NotebookServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -6178,14 +9813,14 @@ def test_assign_notebook_runtime_rest_interceptors(null_interceptor): ) as transcode, mock.patch.object( operation.Operation, "_set_result_from_operation" ), mock.patch.object( - 
transports.NotebookServiceRestInterceptor, "post_assign_notebook_runtime" + transports.NotebookServiceRestInterceptor, "post_delete_notebook_runtime" ) as post, mock.patch.object( - transports.NotebookServiceRestInterceptor, "pre_assign_notebook_runtime" + transports.NotebookServiceRestInterceptor, "pre_delete_notebook_runtime" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = notebook_service.AssignNotebookRuntimeRequest.pb( - notebook_service.AssignNotebookRuntimeRequest() + pb_message = notebook_service.DeleteNotebookRuntimeRequest.pb( + notebook_service.DeleteNotebookRuntimeRequest() ) transcode.return_value = { "method": "post", @@ -6201,7 +9836,7 @@ def test_assign_notebook_runtime_rest_interceptors(null_interceptor): operations_pb2.Operation() ) - request = notebook_service.AssignNotebookRuntimeRequest() + request = notebook_service.DeleteNotebookRuntimeRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), @@ -6209,7 +9844,7 @@ def test_assign_notebook_runtime_rest_interceptors(null_interceptor): pre.return_value = request, metadata post.return_value = operations_pb2.Operation() - client.assign_notebook_runtime( + client.delete_notebook_runtime( request, metadata=[ ("key", "val"), @@ -6221,8 +9856,8 @@ def test_assign_notebook_runtime_rest_interceptors(null_interceptor): post.assert_called_once() -def test_assign_notebook_runtime_rest_bad_request( - transport: str = "rest", request_type=notebook_service.AssignNotebookRuntimeRequest +def test_delete_notebook_runtime_rest_bad_request( + transport: str = "rest", request_type=notebook_service.DeleteNotebookRuntimeRequest ): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -6230,7 +9865,9 @@ def test_assign_notebook_runtime_rest_bad_request( ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "name": "projects/sample1/locations/sample2/notebookRuntimes/sample3" 
+ } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -6242,10 +9879,10 @@ def test_assign_notebook_runtime_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.assign_notebook_runtime(request) + client.delete_notebook_runtime(request) -def test_assign_notebook_runtime_rest_flattened(): +def test_delete_notebook_runtime_rest_flattened(): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -6257,14 +9894,13 @@ def test_assign_notebook_runtime_rest_flattened(): return_value = operations_pb2.Operation(name="operations/spam") # get arguments that satisfy an http rule for this method - sample_request = {"parent": "projects/sample1/locations/sample2"} + sample_request = { + "name": "projects/sample1/locations/sample2/notebookRuntimes/sample3" + } # get truthy value for each flattened field mock_args = dict( - parent="parent_value", - notebook_runtime_template="notebook_runtime_template_value", - notebook_runtime=gca_notebook_runtime.NotebookRuntime(name="name_value"), - notebook_runtime_id="notebook_runtime_id_value", + name="name_value", ) mock_args.update(sample_request) @@ -6275,20 +9911,20 @@ def test_assign_notebook_runtime_rest_flattened(): response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - client.assign_notebook_runtime(**mock_args) + client.delete_notebook_runtime(**mock_args) # Establish that the underlying call was made with the expected # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{parent=projects/*/locations/*}/notebookRuntimes:assign" + "%s/v1beta1/{name=projects/*/locations/*/notebookRuntimes/*}" % client.transport._host, args[1], ) -def test_assign_notebook_runtime_rest_flattened_error(transport: str = "rest"): +def test_delete_notebook_runtime_rest_flattened_error(transport: str = "rest"): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -6297,16 +9933,13 @@ def test_assign_notebook_runtime_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.assign_notebook_runtime( - notebook_service.AssignNotebookRuntimeRequest(), - parent="parent_value", - notebook_runtime_template="notebook_runtime_template_value", - notebook_runtime=gca_notebook_runtime.NotebookRuntime(name="name_value"), - notebook_runtime_id="notebook_runtime_id_value", + client.delete_notebook_runtime( + notebook_service.DeleteNotebookRuntimeRequest(), + name="name_value", ) -def test_assign_notebook_runtime_rest_error(): +def test_delete_notebook_runtime_rest_error(): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -6315,11 +9948,11 @@ def test_assign_notebook_runtime_rest_error(): @pytest.mark.parametrize( "request_type", [ - notebook_service.GetNotebookRuntimeRequest, + notebook_service.UpgradeNotebookRuntimeRequest, dict, ], ) -def test_get_notebook_runtime_rest(request_type): +def test_upgrade_notebook_runtime_rest(request_type): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -6334,55 +9967,68 @@ def test_get_notebook_runtime_rest(request_type): # Mock the http request call within the method and fake a response. 
with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = notebook_runtime.NotebookRuntime( - name="name_value", - runtime_user="runtime_user_value", - proxy_uri="proxy_uri_value", - health_state=notebook_runtime.NotebookRuntime.HealthState.HEALTHY, - display_name="display_name_value", - description="description_value", - service_account="service_account_value", - runtime_state=notebook_runtime.NotebookRuntime.RuntimeState.RUNNING, - is_upgradable=True, - version="version_value", - notebook_runtime_type=notebook_runtime.NotebookRuntimeType.USER_DEFINED, - network_tags=["network_tags_value"], - ) + return_value = operations_pb2.Operation(name="operations/spam") # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - # Convert return value to protobuf type - return_value = notebook_runtime.NotebookRuntime.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.get_notebook_runtime(request) + response = client.upgrade_notebook_runtime(request) # Establish that the response is the type that we expect. 
- assert isinstance(response, notebook_runtime.NotebookRuntime) - assert response.name == "name_value" - assert response.runtime_user == "runtime_user_value" - assert response.proxy_uri == "proxy_uri_value" - assert response.health_state == notebook_runtime.NotebookRuntime.HealthState.HEALTHY - assert response.display_name == "display_name_value" - assert response.description == "description_value" - assert response.service_account == "service_account_value" - assert ( - response.runtime_state == notebook_runtime.NotebookRuntime.RuntimeState.RUNNING - ) - assert response.is_upgradable is True - assert response.version == "version_value" - assert ( - response.notebook_runtime_type - == notebook_runtime.NotebookRuntimeType.USER_DEFINED - ) - assert response.network_tags == ["network_tags_value"] + assert response.operation.name == "operations/spam" -def test_get_notebook_runtime_rest_required_fields( - request_type=notebook_service.GetNotebookRuntimeRequest, +def test_upgrade_notebook_runtime_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.upgrade_notebook_runtime + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.upgrade_notebook_runtime + ] = mock_rpc + + request = {} + client.upgrade_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.upgrade_notebook_runtime(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_upgrade_notebook_runtime_rest_required_fields( + request_type=notebook_service.UpgradeNotebookRuntimeRequest, ): transport_class = transports.NotebookServiceRestTransport @@ -6398,7 +10044,7 @@ def test_get_notebook_runtime_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).get_notebook_runtime._get_unset_required_fields(jsonified_request) + ).upgrade_notebook_runtime._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present @@ -6407,7 +10053,7 @@ def test_get_notebook_runtime_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).get_notebook_runtime._get_unset_required_fields(jsonified_request) + ).upgrade_notebook_runtime._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone @@ -6421,7 +10067,7 @@ def test_get_notebook_runtime_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = notebook_runtime.NotebookRuntime() + return_value = operations_pb2.Operation(name="operations/spam") # Mock the http request call within the method and fake a response. 
with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -6433,39 +10079,37 @@ def test_get_notebook_runtime_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "get", + "method": "post", "query_params": pb_request, } + transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 - - # Convert return value to protobuf type - return_value = notebook_runtime.NotebookRuntime.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.get_notebook_runtime(request) + response = client.upgrade_notebook_runtime(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_get_notebook_runtime_rest_unset_required_fields(): +def test_upgrade_notebook_runtime_rest_unset_required_fields(): transport = transports.NotebookServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.get_notebook_runtime._get_unset_required_fields({}) + unset_fields = transport.upgrade_notebook_runtime._get_unset_required_fields({}) assert set(unset_fields) == (set(()) & set(("name",))) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_get_notebook_runtime_rest_interceptors(null_interceptor): +def test_upgrade_notebook_runtime_rest_interceptors(null_interceptor): transport = transports.NotebookServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -6478,14 +10122,16 @@ def test_get_notebook_runtime_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - transports.NotebookServiceRestInterceptor, "post_get_notebook_runtime" + 
operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.NotebookServiceRestInterceptor, "post_upgrade_notebook_runtime" ) as post, mock.patch.object( - transports.NotebookServiceRestInterceptor, "pre_get_notebook_runtime" + transports.NotebookServiceRestInterceptor, "pre_upgrade_notebook_runtime" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = notebook_service.GetNotebookRuntimeRequest.pb( - notebook_service.GetNotebookRuntimeRequest() + pb_message = notebook_service.UpgradeNotebookRuntimeRequest.pb( + notebook_service.UpgradeNotebookRuntimeRequest() ) transcode.return_value = { "method": "post", @@ -6497,19 +10143,19 @@ def test_get_notebook_runtime_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = notebook_runtime.NotebookRuntime.to_json( - notebook_runtime.NotebookRuntime() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() ) - request = notebook_service.GetNotebookRuntimeRequest() + request = notebook_service.UpgradeNotebookRuntimeRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = notebook_runtime.NotebookRuntime() + post.return_value = operations_pb2.Operation() - client.get_notebook_runtime( + client.upgrade_notebook_runtime( request, metadata=[ ("key", "val"), @@ -6521,8 +10167,8 @@ def test_get_notebook_runtime_rest_interceptors(null_interceptor): post.assert_called_once() -def test_get_notebook_runtime_rest_bad_request( - transport: str = "rest", request_type=notebook_service.GetNotebookRuntimeRequest +def test_upgrade_notebook_runtime_rest_bad_request( + transport: str = "rest", request_type=notebook_service.UpgradeNotebookRuntimeRequest ): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -6544,10 +10190,10 @@ def 
test_get_notebook_runtime_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.get_notebook_runtime(request) + client.upgrade_notebook_runtime(request) -def test_get_notebook_runtime_rest_flattened(): +def test_upgrade_notebook_runtime_rest_flattened(): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -6556,7 +10202,7 @@ def test_get_notebook_runtime_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = notebook_runtime.NotebookRuntime() + return_value = operations_pb2.Operation(name="operations/spam") # get arguments that satisfy an http rule for this method sample_request = { @@ -6572,26 +10218,24 @@ def test_get_notebook_runtime_rest_flattened(): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - # Convert return value to protobuf type - return_value = notebook_runtime.NotebookRuntime.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - client.get_notebook_runtime(**mock_args) + client.upgrade_notebook_runtime(**mock_args) # Establish that the underlying call was made with the expected # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{name=projects/*/locations/*/notebookRuntimes/*}" + "%s/v1beta1/{name=projects/*/locations/*/notebookRuntimes/*}:upgrade" % client.transport._host, args[1], ) -def test_get_notebook_runtime_rest_flattened_error(transport: str = "rest"): +def test_upgrade_notebook_runtime_rest_flattened_error(transport: str = "rest"): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -6600,13 +10244,13 @@ def test_get_notebook_runtime_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.get_notebook_runtime( - notebook_service.GetNotebookRuntimeRequest(), + client.upgrade_notebook_runtime( + notebook_service.UpgradeNotebookRuntimeRequest(), name="name_value", ) -def test_get_notebook_runtime_rest_error(): +def test_upgrade_notebook_runtime_rest_error(): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -6615,50 +10259,92 @@ def test_get_notebook_runtime_rest_error(): @pytest.mark.parametrize( "request_type", [ - notebook_service.ListNotebookRuntimesRequest, + notebook_service.StartNotebookRuntimeRequest, dict, ], ) -def test_list_notebook_runtimes_rest(request_type): +def test_start_notebook_runtime_rest(request_type): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "name": "projects/sample1/locations/sample2/notebookRuntimes/sample3" + } request = request_type(**request_init) # Mock the http request call within the method and fake a response. 
with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = notebook_service.ListNotebookRuntimesResponse( - next_page_token="next_page_token_value", - ) + return_value = operations_pb2.Operation(name="operations/spam") # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - # Convert return value to protobuf type - return_value = notebook_service.ListNotebookRuntimesResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.list_notebook_runtimes(request) + response = client.start_notebook_runtime(request) + + # Establish that the response is the type that we expect. + assert response.operation.name == "operations/spam" + + +def test_start_notebook_runtime_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.start_notebook_runtime + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.start_notebook_runtime + ] = mock_rpc + + request = {} + client.start_notebook_runtime(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.start_notebook_runtime(request) - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListNotebookRuntimesPager) - assert response.next_page_token == "next_page_token_value" + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 -def test_list_notebook_runtimes_rest_required_fields( - request_type=notebook_service.ListNotebookRuntimesRequest, +def test_start_notebook_runtime_rest_required_fields( + request_type=notebook_service.StartNotebookRuntimeRequest, ): transport_class = transports.NotebookServiceRestTransport request_init = {} - request_init["parent"] = "" + request_init["name"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -6669,31 +10355,21 @@ def test_list_notebook_runtimes_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).list_notebook_runtimes._get_unset_required_fields(jsonified_request) + ).start_notebook_runtime._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["parent"] = "parent_value" + jsonified_request["name"] = "name_value" unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).list_notebook_runtimes._get_unset_required_fields(jsonified_request) - # Check that path parameters and body parameters are not mixing in. 
- assert not set(unset_fields) - set( - ( - "filter", - "order_by", - "page_size", - "page_token", - "read_mask", - ) - ) + ).start_notebook_runtime._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "parent" in jsonified_request - assert jsonified_request["parent"] == "parent_value" + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -6702,7 +10378,7 @@ def test_list_notebook_runtimes_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = notebook_service.ListNotebookRuntimesResponse() + return_value = operations_pb2.Operation(name="operations/spam") # Mock the http request call within the method and fake a response. with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -6714,52 +10390,37 @@ def test_list_notebook_runtimes_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "get", + "method": "post", "query_params": pb_request, } + transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 - - # Convert return value to protobuf type - return_value = notebook_service.ListNotebookRuntimesResponse.pb( - return_value - ) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.list_notebook_runtimes(request) + response = client.start_notebook_runtime(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_list_notebook_runtimes_rest_unset_required_fields(): +def 
test_start_notebook_runtime_rest_unset_required_fields(): transport = transports.NotebookServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.list_notebook_runtimes._get_unset_required_fields({}) - assert set(unset_fields) == ( - set( - ( - "filter", - "orderBy", - "pageSize", - "pageToken", - "readMask", - ) - ) - & set(("parent",)) - ) + unset_fields = transport.start_notebook_runtime._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_list_notebook_runtimes_rest_interceptors(null_interceptor): +def test_start_notebook_runtime_rest_interceptors(null_interceptor): transport = transports.NotebookServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -6772,14 +10433,16 @@ def test_list_notebook_runtimes_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - transports.NotebookServiceRestInterceptor, "post_list_notebook_runtimes" + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.NotebookServiceRestInterceptor, "post_start_notebook_runtime" ) as post, mock.patch.object( - transports.NotebookServiceRestInterceptor, "pre_list_notebook_runtimes" + transports.NotebookServiceRestInterceptor, "pre_start_notebook_runtime" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = notebook_service.ListNotebookRuntimesRequest.pb( - notebook_service.ListNotebookRuntimesRequest() + pb_message = notebook_service.StartNotebookRuntimeRequest.pb( + notebook_service.StartNotebookRuntimeRequest() ) transcode.return_value = { "method": "post", @@ -6791,21 +10454,19 @@ def test_list_notebook_runtimes_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = 
( - notebook_service.ListNotebookRuntimesResponse.to_json( - notebook_service.ListNotebookRuntimesResponse() - ) + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() ) - request = notebook_service.ListNotebookRuntimesRequest() + request = notebook_service.StartNotebookRuntimeRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = notebook_service.ListNotebookRuntimesResponse() + post.return_value = operations_pb2.Operation() - client.list_notebook_runtimes( + client.start_notebook_runtime( request, metadata=[ ("key", "val"), @@ -6817,8 +10478,8 @@ def test_list_notebook_runtimes_rest_interceptors(null_interceptor): post.assert_called_once() -def test_list_notebook_runtimes_rest_bad_request( - transport: str = "rest", request_type=notebook_service.ListNotebookRuntimesRequest +def test_start_notebook_runtime_rest_bad_request( + transport: str = "rest", request_type=notebook_service.StartNotebookRuntimeRequest ): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -6826,7 +10487,9 @@ def test_list_notebook_runtimes_rest_bad_request( ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "name": "projects/sample1/locations/sample2/notebookRuntimes/sample3" + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. 
@@ -6838,10 +10501,10 @@ def test_list_notebook_runtimes_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.list_notebook_runtimes(request) + client.start_notebook_runtime(request) -def test_list_notebook_runtimes_rest_flattened(): +def test_start_notebook_runtime_rest_flattened(): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -6850,40 +10513,40 @@ def test_list_notebook_runtimes_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = notebook_service.ListNotebookRuntimesResponse() + return_value = operations_pb2.Operation(name="operations/spam") # get arguments that satisfy an http rule for this method - sample_request = {"parent": "projects/sample1/locations/sample2"} + sample_request = { + "name": "projects/sample1/locations/sample2/notebookRuntimes/sample3" + } # get truthy value for each flattened field mock_args = dict( - parent="parent_value", + name="name_value", ) mock_args.update(sample_request) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - # Convert return value to protobuf type - return_value = notebook_service.ListNotebookRuntimesResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - client.list_notebook_runtimes(**mock_args) + client.start_notebook_runtime(**mock_args) # Establish that the underlying call was made with the expected # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{parent=projects/*/locations/*}/notebookRuntimes" + "%s/v1beta1/{name=projects/*/locations/*/notebookRuntimes/*}:start" % client.transport._host, args[1], ) -def test_list_notebook_runtimes_rest_flattened_error(transport: str = "rest"): +def test_start_notebook_runtime_rest_flattened_error(transport: str = "rest"): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -6892,83 +10555,26 @@ def test_list_notebook_runtimes_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.list_notebook_runtimes( - notebook_service.ListNotebookRuntimesRequest(), - parent="parent_value", + client.start_notebook_runtime( + notebook_service.StartNotebookRuntimeRequest(), + name="name_value", ) -def test_list_notebook_runtimes_rest_pager(transport: str = "rest"): +def test_start_notebook_runtime_rest_error(): client = NotebookServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, + credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) - # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, "request") as req: - # TODO(kbandes): remove this mock unless there's a good reason for it. 
- # with mock.patch.object(path_template, 'transcode') as transcode: - # Set the response as a series of pages - response = ( - notebook_service.ListNotebookRuntimesResponse( - notebook_runtimes=[ - notebook_runtime.NotebookRuntime(), - notebook_runtime.NotebookRuntime(), - notebook_runtime.NotebookRuntime(), - ], - next_page_token="abc", - ), - notebook_service.ListNotebookRuntimesResponse( - notebook_runtimes=[], - next_page_token="def", - ), - notebook_service.ListNotebookRuntimesResponse( - notebook_runtimes=[ - notebook_runtime.NotebookRuntime(), - ], - next_page_token="ghi", - ), - notebook_service.ListNotebookRuntimesResponse( - notebook_runtimes=[ - notebook_runtime.NotebookRuntime(), - notebook_runtime.NotebookRuntime(), - ], - ), - ) - # Two responses for two calls - response = response + response - - # Wrap the values into proper Response objs - response = tuple( - notebook_service.ListNotebookRuntimesResponse.to_json(x) for x in response - ) - return_values = tuple(Response() for i in response) - for return_val, response_val in zip(return_values, response): - return_val._content = response_val.encode("UTF-8") - return_val.status_code = 200 - req.side_effect = return_values - - sample_request = {"parent": "projects/sample1/locations/sample2"} - - pager = client.list_notebook_runtimes(request=sample_request) - - results = list(pager) - assert len(results) == 6 - assert all(isinstance(i, notebook_runtime.NotebookRuntime) for i in results) - - pages = list(client.list_notebook_runtimes(request=sample_request).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - @pytest.mark.parametrize( "request_type", [ - notebook_service.DeleteNotebookRuntimeRequest, + notebook_service.GetNotebookExecutionJobRequest, dict, ], ) -def test_delete_notebook_runtime_rest(request_type): +def test_get_notebook_execution_job_rest(request_type): client = NotebookServiceClient( 
credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -6976,30 +10582,85 @@ def test_delete_notebook_runtime_rest(request_type): # send a request that will satisfy transcoding request_init = { - "name": "projects/sample1/locations/sample2/notebookRuntimes/sample3" + "name": "projects/sample1/locations/sample2/notebookExecutionJobs/sample3" } request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + return_value = notebook_execution_job.NotebookExecutionJob( + name="name_value", + display_name="display_name_value", + schedule_resource_name="schedule_resource_name_value", + job_state=job_state.JobState.JOB_STATE_QUEUED, + notebook_runtime_template_resource_name="notebook_runtime_template_resource_name_value", + gcs_output_uri="gcs_output_uri_value", + execution_user="execution_user_value", + ) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 + # Convert return value to protobuf type + return_value = notebook_execution_job.NotebookExecutionJob.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.delete_notebook_runtime(request) + response = client.get_notebook_execution_job(request) # Establish that the response is the type that we expect. 
- assert response.operation.name == "operations/spam" + assert isinstance(response, notebook_execution_job.NotebookExecutionJob) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.schedule_resource_name == "schedule_resource_name_value" + assert response.job_state == job_state.JobState.JOB_STATE_QUEUED -def test_delete_notebook_runtime_rest_required_fields( - request_type=notebook_service.DeleteNotebookRuntimeRequest, +def test_get_notebook_execution_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_notebook_execution_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_notebook_execution_job + ] = mock_rpc + + request = {} + client.get_notebook_execution_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_notebook_execution_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_get_notebook_execution_job_rest_required_fields( + request_type=notebook_service.GetNotebookExecutionJobRequest, ): transport_class = transports.NotebookServiceRestTransport @@ -7015,7 +10676,7 @@ def test_delete_notebook_runtime_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).delete_notebook_runtime._get_unset_required_fields(jsonified_request) + ).get_notebook_execution_job._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present @@ -7024,7 +10685,9 @@ def test_delete_notebook_runtime_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).delete_notebook_runtime._get_unset_required_fields(jsonified_request) + ).get_notebook_execution_job._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("view",)) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone @@ -7038,7 +10701,7 @@ def test_delete_notebook_runtime_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + return_value = notebook_execution_job.NotebookExecutionJob() # Mock the http request call within the method and fake a response. 
with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -7050,36 +10713,39 @@ def test_delete_notebook_runtime_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "delete", + "method": "get", "query_params": pb_request, } transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = notebook_execution_job.NotebookExecutionJob.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.delete_notebook_runtime(request) + response = client.get_notebook_execution_job(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_delete_notebook_runtime_rest_unset_required_fields(): +def test_get_notebook_execution_job_rest_unset_required_fields(): transport = transports.NotebookServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.delete_notebook_runtime._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("name",))) + unset_fields = transport.get_notebook_execution_job._get_unset_required_fields({}) + assert set(unset_fields) == (set(("view",)) & set(("name",))) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_delete_notebook_runtime_rest_interceptors(null_interceptor): +def test_get_notebook_execution_job_rest_interceptors(null_interceptor): transport = transports.NotebookServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -7092,16 +10758,14 @@ def test_delete_notebook_runtime_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - 
operation.Operation, "_set_result_from_operation" - ), mock.patch.object( - transports.NotebookServiceRestInterceptor, "post_delete_notebook_runtime" + transports.NotebookServiceRestInterceptor, "post_get_notebook_execution_job" ) as post, mock.patch.object( - transports.NotebookServiceRestInterceptor, "pre_delete_notebook_runtime" + transports.NotebookServiceRestInterceptor, "pre_get_notebook_execution_job" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = notebook_service.DeleteNotebookRuntimeRequest.pb( - notebook_service.DeleteNotebookRuntimeRequest() + pb_message = notebook_service.GetNotebookExecutionJobRequest.pb( + notebook_service.GetNotebookExecutionJobRequest() ) transcode.return_value = { "method": "post", @@ -7113,19 +10777,19 @@ def test_delete_notebook_runtime_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = json_format.MessageToJson( - operations_pb2.Operation() + req.return_value._content = notebook_execution_job.NotebookExecutionJob.to_json( + notebook_execution_job.NotebookExecutionJob() ) - request = notebook_service.DeleteNotebookRuntimeRequest() + request = notebook_service.GetNotebookExecutionJobRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = operations_pb2.Operation() + post.return_value = notebook_execution_job.NotebookExecutionJob() - client.delete_notebook_runtime( + client.get_notebook_execution_job( request, metadata=[ ("key", "val"), @@ -7137,8 +10801,9 @@ def test_delete_notebook_runtime_rest_interceptors(null_interceptor): post.assert_called_once() -def test_delete_notebook_runtime_rest_bad_request( - transport: str = "rest", request_type=notebook_service.DeleteNotebookRuntimeRequest +def test_get_notebook_execution_job_rest_bad_request( + transport: str = "rest", + 
request_type=notebook_service.GetNotebookExecutionJobRequest, ): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -7147,7 +10812,7 @@ def test_delete_notebook_runtime_rest_bad_request( # send a request that will satisfy transcoding request_init = { - "name": "projects/sample1/locations/sample2/notebookRuntimes/sample3" + "name": "projects/sample1/locations/sample2/notebookExecutionJobs/sample3" } request = request_type(**request_init) @@ -7160,10 +10825,10 @@ def test_delete_notebook_runtime_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.delete_notebook_runtime(request) + client.get_notebook_execution_job(request) -def test_delete_notebook_runtime_rest_flattened(): +def test_get_notebook_execution_job_rest_flattened(): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -7172,11 +10837,11 @@ def test_delete_notebook_runtime_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
- return_value = operations_pb2.Operation(name="operations/spam") + return_value = notebook_execution_job.NotebookExecutionJob() # get arguments that satisfy an http rule for this method sample_request = { - "name": "projects/sample1/locations/sample2/notebookRuntimes/sample3" + "name": "projects/sample1/locations/sample2/notebookExecutionJobs/sample3" } # get truthy value for each flattened field @@ -7188,24 +10853,26 @@ def test_delete_notebook_runtime_rest_flattened(): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 + # Convert return value to protobuf type + return_value = notebook_execution_job.NotebookExecutionJob.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - client.delete_notebook_runtime(**mock_args) + client.get_notebook_execution_job(**mock_args) # Establish that the underlying call was made with the expected # request object values. assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{name=projects/*/locations/*/notebookRuntimes/*}" + "%s/v1beta1/{name=projects/*/locations/*/notebookExecutionJobs/*}" % client.transport._host, args[1], ) -def test_delete_notebook_runtime_rest_flattened_error(transport: str = "rest"): +def test_get_notebook_execution_job_rest_flattened_error(transport: str = "rest"): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -7214,13 +10881,13 @@ def test_delete_notebook_runtime_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - client.delete_notebook_runtime( - notebook_service.DeleteNotebookRuntimeRequest(), + client.get_notebook_execution_job( + notebook_service.GetNotebookExecutionJobRequest(), name="name_value", ) -def test_delete_notebook_runtime_rest_error(): +def test_get_notebook_execution_job_rest_error(): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -7229,47 +10896,93 @@ def test_delete_notebook_runtime_rest_error(): @pytest.mark.parametrize( "request_type", [ - notebook_service.UpgradeNotebookRuntimeRequest, + notebook_service.ListNotebookExecutionJobsRequest, dict, ], ) -def test_upgrade_notebook_runtime_rest(request_type): +def test_list_notebook_execution_jobs_rest(request_type): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", ) # send a request that will satisfy transcoding - request_init = { - "name": "projects/sample1/locations/sample2/notebookRuntimes/sample3" - } + request_init = {"parent": "projects/sample1/locations/sample2"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
- return_value = operations_pb2.Operation(name="operations/spam") + return_value = notebook_service.ListNotebookExecutionJobsResponse( + next_page_token="next_page_token_value", + ) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 + # Convert return value to protobuf type + return_value = notebook_service.ListNotebookExecutionJobsResponse.pb( + return_value + ) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.upgrade_notebook_runtime(request) + response = client.list_notebook_execution_jobs(request) # Establish that the response is the type that we expect. - assert response.operation.name == "operations/spam" + assert isinstance(response, pagers.ListNotebookExecutionJobsPager) + assert response.next_page_token == "next_page_token_value" -def test_upgrade_notebook_runtime_rest_required_fields( - request_type=notebook_service.UpgradeNotebookRuntimeRequest, +def test_list_notebook_execution_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_notebook_execution_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_notebook_execution_jobs + ] = mock_rpc + + request = {} + client.list_notebook_execution_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_notebook_execution_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_list_notebook_execution_jobs_rest_required_fields( + request_type=notebook_service.ListNotebookExecutionJobsRequest, ): transport_class = transports.NotebookServiceRestTransport request_init = {} - request_init["name"] = "" + request_init["parent"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -7280,21 +10993,31 @@ def test_upgrade_notebook_runtime_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).upgrade_notebook_runtime._get_unset_required_fields(jsonified_request) + ).list_notebook_execution_jobs._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["name"] = "name_value" + jsonified_request["parent"] = "parent_value" unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).upgrade_notebook_runtime._get_unset_required_fields(jsonified_request) + ).list_notebook_execution_jobs._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. 
+ assert not set(unset_fields) - set( + ( + "filter", + "order_by", + "page_size", + "page_token", + "view", + ) + ) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "name" in jsonified_request - assert jsonified_request["name"] == "name_value" + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -7303,7 +11026,7 @@ def test_upgrade_notebook_runtime_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + return_value = notebook_service.ListNotebookExecutionJobsResponse() # Mock the http request call within the method and fake a response. with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -7315,37 +11038,52 @@ def test_upgrade_notebook_runtime_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "post", + "method": "get", "query_params": pb_request, } - transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = notebook_service.ListNotebookExecutionJobsResponse.pb( + return_value + ) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.upgrade_notebook_runtime(request) + response = client.list_notebook_execution_jobs(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_upgrade_notebook_runtime_rest_unset_required_fields(): +def 
test_list_notebook_execution_jobs_rest_unset_required_fields(): transport = transports.NotebookServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.upgrade_notebook_runtime._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("name",))) + unset_fields = transport.list_notebook_execution_jobs._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "filter", + "orderBy", + "pageSize", + "pageToken", + "view", + ) + ) + & set(("parent",)) + ) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_upgrade_notebook_runtime_rest_interceptors(null_interceptor): +def test_list_notebook_execution_jobs_rest_interceptors(null_interceptor): transport = transports.NotebookServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -7358,16 +11096,14 @@ def test_upgrade_notebook_runtime_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - operation.Operation, "_set_result_from_operation" - ), mock.patch.object( - transports.NotebookServiceRestInterceptor, "post_upgrade_notebook_runtime" + transports.NotebookServiceRestInterceptor, "post_list_notebook_execution_jobs" ) as post, mock.patch.object( - transports.NotebookServiceRestInterceptor, "pre_upgrade_notebook_runtime" + transports.NotebookServiceRestInterceptor, "pre_list_notebook_execution_jobs" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = notebook_service.UpgradeNotebookRuntimeRequest.pb( - notebook_service.UpgradeNotebookRuntimeRequest() + pb_message = notebook_service.ListNotebookExecutionJobsRequest.pb( + notebook_service.ListNotebookExecutionJobsRequest() ) transcode.return_value = { "method": "post", @@ -7379,19 +11115,21 @@ def test_upgrade_notebook_runtime_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 
req.return_value.request = PreparedRequest() - req.return_value._content = json_format.MessageToJson( - operations_pb2.Operation() + req.return_value._content = ( + notebook_service.ListNotebookExecutionJobsResponse.to_json( + notebook_service.ListNotebookExecutionJobsResponse() + ) ) - request = notebook_service.UpgradeNotebookRuntimeRequest() + request = notebook_service.ListNotebookExecutionJobsRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = operations_pb2.Operation() + post.return_value = notebook_service.ListNotebookExecutionJobsResponse() - client.upgrade_notebook_runtime( + client.list_notebook_execution_jobs( request, metadata=[ ("key", "val"), @@ -7403,8 +11141,9 @@ def test_upgrade_notebook_runtime_rest_interceptors(null_interceptor): post.assert_called_once() -def test_upgrade_notebook_runtime_rest_bad_request( - transport: str = "rest", request_type=notebook_service.UpgradeNotebookRuntimeRequest +def test_list_notebook_execution_jobs_rest_bad_request( + transport: str = "rest", + request_type=notebook_service.ListNotebookExecutionJobsRequest, ): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -7412,9 +11151,7 @@ def test_upgrade_notebook_runtime_rest_bad_request( ) # send a request that will satisfy transcoding - request_init = { - "name": "projects/sample1/locations/sample2/notebookRuntimes/sample3" - } + request_init = {"parent": "projects/sample1/locations/sample2"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. 
@@ -7426,10 +11163,10 @@ def test_upgrade_notebook_runtime_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.upgrade_notebook_runtime(request) + client.list_notebook_execution_jobs(request) -def test_upgrade_notebook_runtime_rest_flattened(): +def test_list_notebook_execution_jobs_rest_flattened(): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -7438,40 +11175,42 @@ def test_upgrade_notebook_runtime_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + return_value = notebook_service.ListNotebookExecutionJobsResponse() # get arguments that satisfy an http rule for this method - sample_request = { - "name": "projects/sample1/locations/sample2/notebookRuntimes/sample3" - } + sample_request = {"parent": "projects/sample1/locations/sample2"} # get truthy value for each flattened field mock_args = dict( - name="name_value", + parent="parent_value", ) mock_args.update(sample_request) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 + # Convert return value to protobuf type + return_value = notebook_service.ListNotebookExecutionJobsResponse.pb( + return_value + ) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - client.upgrade_notebook_runtime(**mock_args) + client.list_notebook_execution_jobs(**mock_args) # Establish that the underlying call was made with the expected # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{name=projects/*/locations/*/notebookRuntimes/*}:upgrade" + "%s/v1beta1/{parent=projects/*/locations/*}/notebookExecutionJobs" % client.transport._host, args[1], ) -def test_upgrade_notebook_runtime_rest_flattened_error(transport: str = "rest"): +def test_list_notebook_execution_jobs_rest_flattened_error(transport: str = "rest"): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -7480,26 +11219,86 @@ def test_upgrade_notebook_runtime_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.upgrade_notebook_runtime( - notebook_service.UpgradeNotebookRuntimeRequest(), - name="name_value", + client.list_notebook_execution_jobs( + notebook_service.ListNotebookExecutionJobsRequest(), + parent="parent_value", ) -def test_upgrade_notebook_runtime_rest_error(): +def test_list_notebook_execution_jobs_rest_pager(transport: str = "rest"): client = NotebookServiceClient( - credentials=ga_credentials.AnonymousCredentials(), transport="rest" + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, ) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. 
+ # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + notebook_service.ListNotebookExecutionJobsResponse( + notebook_execution_jobs=[ + notebook_execution_job.NotebookExecutionJob(), + notebook_execution_job.NotebookExecutionJob(), + notebook_execution_job.NotebookExecutionJob(), + ], + next_page_token="abc", + ), + notebook_service.ListNotebookExecutionJobsResponse( + notebook_execution_jobs=[], + next_page_token="def", + ), + notebook_service.ListNotebookExecutionJobsResponse( + notebook_execution_jobs=[ + notebook_execution_job.NotebookExecutionJob(), + ], + next_page_token="ghi", + ), + notebook_service.ListNotebookExecutionJobsResponse( + notebook_execution_jobs=[ + notebook_execution_job.NotebookExecutionJob(), + notebook_execution_job.NotebookExecutionJob(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + notebook_service.ListNotebookExecutionJobsResponse.to_json(x) + for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {"parent": "projects/sample1/locations/sample2"} + + pager = client.list_notebook_execution_jobs(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all( + isinstance(i, notebook_execution_job.NotebookExecutionJob) for i in results + ) + + pages = list(client.list_notebook_execution_jobs(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + @pytest.mark.parametrize( "request_type", [ - notebook_service.StartNotebookRuntimeRequest, + notebook_service.DeleteNotebookExecutionJobRequest, dict, ], ) -def 
test_start_notebook_runtime_rest(request_type): +def test_delete_notebook_execution_job_rest(request_type): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -7507,7 +11306,7 @@ def test_start_notebook_runtime_rest(request_type): # send a request that will satisfy transcoding request_init = { - "name": "projects/sample1/locations/sample2/notebookRuntimes/sample3" + "name": "projects/sample1/locations/sample2/notebookExecutionJobs/sample3" } request = request_type(**request_init) @@ -7523,14 +11322,59 @@ def test_start_notebook_runtime_rest(request_type): response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.start_notebook_runtime(request) + response = client.delete_notebook_execution_job(request) # Establish that the response is the type that we expect. assert response.operation.name == "operations/spam" -def test_start_notebook_runtime_rest_required_fields( - request_type=notebook_service.StartNotebookRuntimeRequest, +def test_delete_notebook_execution_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = NotebookServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_notebook_execution_job + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_notebook_execution_job + ] = mock_rpc + + request = {} + client.delete_notebook_execution_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_notebook_execution_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_delete_notebook_execution_job_rest_required_fields( + request_type=notebook_service.DeleteNotebookExecutionJobRequest, ): transport_class = transports.NotebookServiceRestTransport @@ -7546,7 +11390,7 @@ def test_start_notebook_runtime_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).start_notebook_runtime._get_unset_required_fields(jsonified_request) + ).delete_notebook_execution_job._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present @@ -7555,7 +11399,7 @@ def test_start_notebook_runtime_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).start_notebook_runtime._get_unset_required_fields(jsonified_request) + ).delete_notebook_execution_job._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone @@ -7581,10 +11425,9 @@ def test_start_notebook_runtime_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "post", + "method": "delete", "query_params": pb_request, } - transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() @@ -7594,24 +11437,26 @@ def 
test_start_notebook_runtime_rest_required_fields( response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.start_notebook_runtime(request) + response = client.delete_notebook_execution_job(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_start_notebook_runtime_rest_unset_required_fields(): +def test_delete_notebook_execution_job_rest_unset_required_fields(): transport = transports.NotebookServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.start_notebook_runtime._get_unset_required_fields({}) + unset_fields = transport.delete_notebook_execution_job._get_unset_required_fields( + {} + ) assert set(unset_fields) == (set(()) & set(("name",))) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_start_notebook_runtime_rest_interceptors(null_interceptor): +def test_delete_notebook_execution_job_rest_interceptors(null_interceptor): transport = transports.NotebookServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -7626,14 +11471,14 @@ def test_start_notebook_runtime_rest_interceptors(null_interceptor): ) as transcode, mock.patch.object( operation.Operation, "_set_result_from_operation" ), mock.patch.object( - transports.NotebookServiceRestInterceptor, "post_start_notebook_runtime" + transports.NotebookServiceRestInterceptor, "post_delete_notebook_execution_job" ) as post, mock.patch.object( - transports.NotebookServiceRestInterceptor, "pre_start_notebook_runtime" + transports.NotebookServiceRestInterceptor, "pre_delete_notebook_execution_job" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = notebook_service.StartNotebookRuntimeRequest.pb( - notebook_service.StartNotebookRuntimeRequest() + pb_message = notebook_service.DeleteNotebookExecutionJobRequest.pb( + notebook_service.DeleteNotebookExecutionJobRequest() ) 
transcode.return_value = { "method": "post", @@ -7649,7 +11494,7 @@ def test_start_notebook_runtime_rest_interceptors(null_interceptor): operations_pb2.Operation() ) - request = notebook_service.StartNotebookRuntimeRequest() + request = notebook_service.DeleteNotebookExecutionJobRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), @@ -7657,7 +11502,7 @@ def test_start_notebook_runtime_rest_interceptors(null_interceptor): pre.return_value = request, metadata post.return_value = operations_pb2.Operation() - client.start_notebook_runtime( + client.delete_notebook_execution_job( request, metadata=[ ("key", "val"), @@ -7669,8 +11514,9 @@ def test_start_notebook_runtime_rest_interceptors(null_interceptor): post.assert_called_once() -def test_start_notebook_runtime_rest_bad_request( - transport: str = "rest", request_type=notebook_service.StartNotebookRuntimeRequest +def test_delete_notebook_execution_job_rest_bad_request( + transport: str = "rest", + request_type=notebook_service.DeleteNotebookExecutionJobRequest, ): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -7679,7 +11525,7 @@ def test_start_notebook_runtime_rest_bad_request( # send a request that will satisfy transcoding request_init = { - "name": "projects/sample1/locations/sample2/notebookRuntimes/sample3" + "name": "projects/sample1/locations/sample2/notebookExecutionJobs/sample3" } request = request_type(**request_init) @@ -7692,10 +11538,10 @@ def test_start_notebook_runtime_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.start_notebook_runtime(request) + client.delete_notebook_execution_job(request) -def test_start_notebook_runtime_rest_flattened(): +def test_delete_notebook_execution_job_rest_flattened(): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -7708,7 +11554,7 @@ def test_start_notebook_runtime_rest_flattened(): # 
get arguments that satisfy an http rule for this method sample_request = { - "name": "projects/sample1/locations/sample2/notebookRuntimes/sample3" + "name": "projects/sample1/locations/sample2/notebookExecutionJobs/sample3" } # get truthy value for each flattened field @@ -7724,20 +11570,20 @@ def test_start_notebook_runtime_rest_flattened(): response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - client.start_notebook_runtime(**mock_args) + client.delete_notebook_execution_job(**mock_args) # Establish that the underlying call was made with the expected # request object values. assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{name=projects/*/locations/*/notebookRuntimes/*}:start" + "%s/v1beta1/{name=projects/*/locations/*/notebookExecutionJobs/*}" % client.transport._host, args[1], ) -def test_start_notebook_runtime_rest_flattened_error(transport: str = "rest"): +def test_delete_notebook_execution_job_rest_flattened_error(transport: str = "rest"): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -7746,13 +11592,13 @@ def test_start_notebook_runtime_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - client.start_notebook_runtime( - notebook_service.StartNotebookRuntimeRequest(), + client.delete_notebook_execution_job( + notebook_service.DeleteNotebookExecutionJobRequest(), name="name_value", ) -def test_start_notebook_runtime_rest_error(): +def test_delete_notebook_execution_job_rest_error(): client = NotebookServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -7907,6 +11753,9 @@ def test_notebook_service_base_transport(): "delete_notebook_runtime", "upgrade_notebook_runtime", "start_notebook_runtime", + "get_notebook_execution_job", + "list_notebook_execution_jobs", + "delete_notebook_execution_job", "set_iam_policy", "get_iam_policy", "test_iam_permissions", @@ -8227,6 +12076,15 @@ def test_notebook_service_client_transport_session_collision(transport_name): session1 = client1.transport.start_notebook_runtime._session session2 = client2.transport.start_notebook_runtime._session assert session1 != session2 + session1 = client1.transport.get_notebook_execution_job._session + session2 = client2.transport.get_notebook_execution_job._session + assert session1 != session2 + session1 = client1.transport.list_notebook_execution_jobs._session + session2 = client2.transport.list_notebook_execution_jobs._session + assert session1 != session2 + session1 = client1.transport.delete_notebook_execution_job._session + session2 = client2.transport.delete_notebook_execution_job._session + assert session1 != session2 def test_notebook_service_grpc_transport_channel(): @@ -8412,10 +12270,38 @@ def test_parse_network_path(): assert expected == actual -def test_notebook_runtime_path(): +def test_notebook_execution_job_path(): project = "oyster" location = "nudibranch" - notebook_runtime = "cuttlefish" + notebook_execution_job = "cuttlefish" + expected = "projects/{project}/locations/{location}/notebookExecutionJobs/{notebook_execution_job}".format( + project=project, + location=location, + 
notebook_execution_job=notebook_execution_job, + ) + actual = NotebookServiceClient.notebook_execution_job_path( + project, location, notebook_execution_job + ) + assert expected == actual + + +def test_parse_notebook_execution_job_path(): + expected = { + "project": "mussel", + "location": "winkle", + "notebook_execution_job": "nautilus", + } + path = NotebookServiceClient.notebook_execution_job_path(**expected) + + # Check that the path construction is reversible. + actual = NotebookServiceClient.parse_notebook_execution_job_path(path) + assert expected == actual + + +def test_notebook_runtime_path(): + project = "scallop" + location = "abalone" + notebook_runtime = "squid" expected = "projects/{project}/locations/{location}/notebookRuntimes/{notebook_runtime}".format( project=project, location=location, @@ -8429,9 +12315,9 @@ def test_notebook_runtime_path(): def test_parse_notebook_runtime_path(): expected = { - "project": "mussel", - "location": "winkle", - "notebook_runtime": "nautilus", + "project": "clam", + "location": "whelk", + "notebook_runtime": "octopus", } path = NotebookServiceClient.notebook_runtime_path(**expected) @@ -8441,9 +12327,9 @@ def test_parse_notebook_runtime_path(): def test_notebook_runtime_template_path(): - project = "scallop" - location = "abalone" - notebook_runtime_template = "squid" + project = "oyster" + location = "nudibranch" + notebook_runtime_template = "cuttlefish" expected = "projects/{project}/locations/{location}/notebookRuntimeTemplates/{notebook_runtime_template}".format( project=project, location=location, @@ -8457,9 +12343,9 @@ def test_notebook_runtime_template_path(): def test_parse_notebook_runtime_template_path(): expected = { - "project": "clam", - "location": "whelk", - "notebook_runtime_template": "octopus", + "project": "mussel", + "location": "winkle", + "notebook_runtime_template": "nautilus", } path = NotebookServiceClient.notebook_runtime_template_path(**expected) @@ -8468,6 +12354,32 @@ def 
test_parse_notebook_runtime_template_path(): assert expected == actual +def test_schedule_path(): + project = "scallop" + location = "abalone" + schedule = "squid" + expected = "projects/{project}/locations/{location}/schedules/{schedule}".format( + project=project, + location=location, + schedule=schedule, + ) + actual = NotebookServiceClient.schedule_path(project, location, schedule) + assert expected == actual + + +def test_parse_schedule_path(): + expected = { + "project": "clam", + "location": "whelk", + "schedule": "octopus", + } + path = NotebookServiceClient.schedule_path(**expected) + + # Check that the path construction is reversible. + actual = NotebookServiceClient.parse_schedule_path(path) + assert expected == actual + + def test_subnetwork_path(): project = "oyster" region = "nudibranch" diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_persistent_resource_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_persistent_resource_service.py index 51044be81c..b319e7c8a7 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_persistent_resource_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_persistent_resource_service.py @@ -1278,6 +1278,9 @@ def test_create_persistent_resource_empty_call(): with mock.patch.object( type(client.transport.create_persistent_resource), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_persistent_resource() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1304,6 +1307,9 @@ def test_create_persistent_resource_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.create_persistent_resource), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_persistent_resource(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1313,6 +1319,50 @@ def test_create_persistent_resource_non_empty_request_with_auto_populated_field( ) +def test_create_persistent_resource_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_persistent_resource + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_persistent_resource + ] = mock_rpc + request = {} + client.create_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_persistent_resource_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1336,6 +1386,56 @@ async def test_create_persistent_resource_empty_call_async(): assert args[0] == persistent_resource_service.CreatePersistentResourceRequest() +@pytest.mark.asyncio +async def test_create_persistent_resource_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PersistentResourceServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_persistent_resource + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_persistent_resource + ] = mock_object + + request = {} + await client.create_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_persistent_resource_async( transport: str = "grpc_asyncio", @@ -1612,6 +1712,9 @@ def test_get_persistent_resource_empty_call(): with mock.patch.object( type(client.transport.get_persistent_resource), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_persistent_resource() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1637,6 +1740,9 @@ def test_get_persistent_resource_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_persistent_resource), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_persistent_resource(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1645,6 +1751,46 @@ def test_get_persistent_resource_non_empty_request_with_auto_populated_field(): ) +def test_get_persistent_resource_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_persistent_resource + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_persistent_resource + ] = mock_rpc + request = {} + client.get_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_persistent_resource_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1674,6 +1820,52 @@ async def test_get_persistent_resource_empty_call_async(): assert args[0] == persistent_resource_service.GetPersistentResourceRequest() +@pytest.mark.asyncio +async def test_get_persistent_resource_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PersistentResourceServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_persistent_resource + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_persistent_resource + ] = mock_object + + request = {} + await client.get_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_persistent_resource_async( transport: str = "grpc_asyncio", @@ -1925,6 +2117,9 @@ def test_list_persistent_resources_empty_call(): with mock.patch.object( type(client.transport.list_persistent_resources), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_persistent_resources() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1951,6 +2146,9 @@ def test_list_persistent_resources_non_empty_request_with_auto_populated_field() with mock.patch.object( type(client.transport.list_persistent_resources), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_persistent_resources(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1960,6 +2158,46 @@ def test_list_persistent_resources_non_empty_request_with_auto_populated_field() ) +def test_list_persistent_resources_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_persistent_resources + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute 
client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_persistent_resources + ] = mock_rpc + request = {} + client.list_persistent_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_persistent_resources(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_persistent_resources_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1985,6 +2223,52 @@ async def test_list_persistent_resources_empty_call_async(): assert args[0] == persistent_resource_service.ListPersistentResourcesRequest() +@pytest.mark.asyncio +async def test_list_persistent_resources_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PersistentResourceServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_persistent_resources + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_persistent_resources + ] = mock_object + + request = {} + await client.list_persistent_resources(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_persistent_resources(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_persistent_resources_async( transport: str = "grpc_asyncio", @@ -2433,6 +2717,9 @@ def test_delete_persistent_resource_empty_call(): with mock.patch.object( type(client.transport.delete_persistent_resource), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_persistent_resource() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2458,6 +2745,9 @@ def test_delete_persistent_resource_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.delete_persistent_resource), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_persistent_resource(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2466,6 +2756,50 @@ def test_delete_persistent_resource_non_empty_request_with_auto_populated_field( ) +def test_delete_persistent_resource_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_persistent_resource + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_persistent_resource + ] = mock_rpc + request = {} + client.delete_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_persistent_resource_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2489,6 +2823,56 @@ async def test_delete_persistent_resource_empty_call_async(): assert args[0] == persistent_resource_service.DeletePersistentResourceRequest() +@pytest.mark.asyncio +async def test_delete_persistent_resource_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PersistentResourceServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_persistent_resource + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_persistent_resource + ] = mock_object + + request = {} + await client.delete_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_persistent_resource_async( transport: str = "grpc_asyncio", @@ -2726,6 +3110,9 @@ def test_update_persistent_resource_empty_call(): with mock.patch.object( type(client.transport.update_persistent_resource), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_persistent_resource() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2749,12 +3136,59 @@ def test_update_persistent_resource_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.update_persistent_resource), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_persistent_resource(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == persistent_resource_service.UpdatePersistentResourceRequest() +def test_update_persistent_resource_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_persistent_resource + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_persistent_resource + ] = mock_rpc + request = {} + client.update_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_persistent_resource_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2778,6 +3212,56 @@ async def test_update_persistent_resource_empty_call_async(): assert args[0] == persistent_resource_service.UpdatePersistentResourceRequest() +@pytest.mark.asyncio +async def test_update_persistent_resource_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PersistentResourceServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_persistent_resource + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_persistent_resource + ] = mock_object + + request = {} + await client.update_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_persistent_resource_async( transport: str = "grpc_asyncio", @@ -3033,6 +3517,9 @@ def test_reboot_persistent_resource_empty_call(): with mock.patch.object( type(client.transport.reboot_persistent_resource), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.reboot_persistent_resource() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3058,6 +3545,9 @@ def test_reboot_persistent_resource_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.reboot_persistent_resource), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.reboot_persistent_resource(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3066,6 +3556,50 @@ def test_reboot_persistent_resource_non_empty_request_with_auto_populated_field( ) +def test_reboot_persistent_resource_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.reboot_persistent_resource + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.reboot_persistent_resource + ] = mock_rpc + request = {} + client.reboot_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.reboot_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_reboot_persistent_resource_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3089,6 +3623,56 @@ async def test_reboot_persistent_resource_empty_call_async(): assert args[0] == persistent_resource_service.RebootPersistentResourceRequest() +@pytest.mark.asyncio +async def test_reboot_persistent_resource_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PersistentResourceServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.reboot_persistent_resource + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.reboot_persistent_resource + ] = mock_object + + request = {} + await client.reboot_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.reboot_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_reboot_persistent_resource_async( transport: str = "grpc_asyncio", @@ -3447,6 +4031,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_persistent_resource_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_persistent_resource + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_persistent_resource + ] = mock_rpc + + request = {} + client.create_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_persistent_resource_rest_required_fields( request_type=persistent_resource_service.CreatePersistentResourceRequest, ): @@ -3759,6 +4388,47 @@ def test_get_persistent_resource_rest(request_type): assert response.reserved_ip_ranges == ["reserved_ip_ranges_value"] +def test_get_persistent_resource_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_persistent_resource + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_persistent_resource + ] = mock_rpc + + request = {} + client.get_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_persistent_resource_rest_required_fields( request_type=persistent_resource_service.GetPersistentResourceRequest, ): @@ -4035,6 +4705,47 @@ def test_list_persistent_resources_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_persistent_resources_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_persistent_resources + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_persistent_resources + ] = mock_rpc + + request = {} + client.list_persistent_resources(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_persistent_resources(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_persistent_resources_rest_required_fields( request_type=persistent_resource_service.ListPersistentResourcesRequest, ): @@ -4387,6 +5098,51 @@ def test_delete_persistent_resource_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_persistent_resource_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_persistent_resource + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_persistent_resource + ] = mock_rpc + + request = {} + client.delete_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_persistent_resource_rest_required_fields( request_type=persistent_resource_service.DeletePersistentResourceRequest, ): @@ -4790,6 +5546,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_persistent_resource_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_persistent_resource + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_persistent_resource + ] = mock_rpc + + request = {} + client.update_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_persistent_resource_rest_required_fields( request_type=persistent_resource_service.UpdatePersistentResourceRequest, ): @@ -5074,6 +5875,51 @@ def test_reboot_persistent_resource_rest(request_type): assert response.operation.name == "operations/spam" +def test_reboot_persistent_resource_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PersistentResourceServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.reboot_persistent_resource + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.reboot_persistent_resource + ] = mock_rpc + + request = {} + client.reboot_persistent_resource(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.reboot_persistent_resource(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_reboot_persistent_resource_rest_required_fields( request_type=persistent_resource_service.RebootPersistentResourceRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py index 37feb13086..57acfe6058 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py @@ -1243,6 +1243,9 @@ def test_create_training_pipeline_empty_call(): with mock.patch.object( type(client.transport.create_training_pipeline), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1268,6 +1271,9 @@ def test_create_training_pipeline_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_training_pipeline), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_training_pipeline(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1276,6 +1282,46 @@ def test_create_training_pipeline_non_empty_request_with_auto_populated_field(): ) +def test_create_training_pipeline_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_training_pipeline + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_training_pipeline + ] = mock_rpc + request = {} + client.create_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_training_pipeline_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1306,6 +1352,52 @@ async def test_create_training_pipeline_empty_call_async(): assert args[0] == pipeline_service.CreateTrainingPipelineRequest() +@pytest.mark.asyncio +async def test_create_training_pipeline_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_training_pipeline + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_training_pipeline + ] = mock_object + + request = {} + await client.create_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_training_pipeline_async( transport: str = "grpc_asyncio", @@ -1579,6 +1671,9 @@ def test_get_training_pipeline_empty_call(): with mock.patch.object( type(client.transport.get_training_pipeline), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1604,6 +1699,9 @@ def test_get_training_pipeline_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_training_pipeline), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_training_pipeline(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1612,6 +1710,46 @@ def test_get_training_pipeline_non_empty_request_with_auto_populated_field(): ) +def test_get_training_pipeline_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_training_pipeline + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.get_training_pipeline + ] = mock_rpc + request = {} + client.get_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_training_pipeline_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1642,6 +1780,52 @@ async def test_get_training_pipeline_empty_call_async(): assert args[0] == pipeline_service.GetTrainingPipelineRequest() +@pytest.mark.asyncio +async def test_get_training_pipeline_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_training_pipeline + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_training_pipeline + ] = mock_object + + request = {} + await client.get_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_training_pipeline_async( transport: str = "grpc_asyncio", @@ -1895,6 +2079,9 @@ def test_list_training_pipelines_empty_call(): with mock.patch.object( type(client.transport.list_training_pipelines), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_training_pipelines() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1922,6 +2109,9 @@ def test_list_training_pipelines_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_training_pipelines), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_training_pipelines(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1932,6 +2122,46 @@ def test_list_training_pipelines_non_empty_request_with_auto_populated_field(): ) +def test_list_training_pipelines_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_training_pipelines + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_training_pipelines + ] = mock_rpc + request = {} + client.list_training_pipelines(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_training_pipelines(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_training_pipelines_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1957,6 +2187,52 @@ async def test_list_training_pipelines_empty_call_async(): assert args[0] == pipeline_service.ListTrainingPipelinesRequest() +@pytest.mark.asyncio +async def test_list_training_pipelines_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_training_pipelines + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_training_pipelines + ] = mock_object + + request = {} + await client.list_training_pipelines(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_training_pipelines(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_training_pipelines_async( transport: str = "grpc_asyncio", @@ -2395,6 +2671,9 @@ def test_delete_training_pipeline_empty_call(): with mock.patch.object( type(client.transport.delete_training_pipeline), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2420,6 +2699,9 @@ def test_delete_training_pipeline_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_training_pipeline), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_training_pipeline(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2428,6 +2710,50 @@ def test_delete_training_pipeline_non_empty_request_with_auto_populated_field(): ) +def test_delete_training_pipeline_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_training_pipeline + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a 
string. + ) + client._transport._wrapped_methods[ + client._transport.delete_training_pipeline + ] = mock_rpc + request = {} + client.delete_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_training_pipeline_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2451,6 +2777,56 @@ async def test_delete_training_pipeline_empty_call_async(): assert args[0] == pipeline_service.DeleteTrainingPipelineRequest() +@pytest.mark.asyncio +async def test_delete_training_pipeline_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_training_pipeline + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_training_pipeline + ] = mock_object + + request = {} + await client.delete_training_pipeline(request) + + # 
Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_training_pipeline_async( transport: str = "grpc_asyncio", @@ -2688,6 +3064,9 @@ def test_cancel_training_pipeline_empty_call(): with mock.patch.object( type(client.transport.cancel_training_pipeline), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.cancel_training_pipeline() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2713,6 +3092,9 @@ def test_cancel_training_pipeline_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.cancel_training_pipeline), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.cancel_training_pipeline(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2721,6 +3103,46 @@ def test_cancel_training_pipeline_non_empty_request_with_auto_populated_field(): ) +def test_cancel_training_pipeline_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.cancel_training_pipeline + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_training_pipeline + ] = mock_rpc + request = {} + client.cancel_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_training_pipeline_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2742,6 +3164,52 @@ async def test_cancel_training_pipeline_empty_call_async(): assert args[0] == pipeline_service.CancelTrainingPipelineRequest() +@pytest.mark.asyncio +async def test_cancel_training_pipeline_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.cancel_training_pipeline + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.cancel_training_pipeline + ] = mock_object + + request = {} + await client.cancel_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.cancel_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_training_pipeline_async( transport: str = "grpc_asyncio", @@ -2992,6 +3460,9 @@ def test_create_pipeline_job_empty_call(): with mock.patch.object( type(client.transport.create_pipeline_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_pipeline_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3018,6 +3489,9 @@ def test_create_pipeline_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_pipeline_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_pipeline_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3027,6 +3501,45 @@ def test_create_pipeline_job_non_empty_request_with_auto_populated_field(): ) +def test_create_pipeline_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_pipeline_job in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.create_pipeline_job + ] = mock_rpc + request = {} + client.create_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.create_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_pipeline_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3060,6 +3573,52 @@ async def test_create_pipeline_job_empty_call_async(): assert args[0] == pipeline_service.CreatePipelineJobRequest() +@pytest.mark.asyncio +async def test_create_pipeline_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_pipeline_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_pipeline_job + ] = mock_object + + request = {} + await client.create_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_pipeline_job_async( transport: str = "grpc_asyncio", @@ -3351,6 +3910,9 @@ def test_get_pipeline_job_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_pipeline_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_pipeline_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3374,6 +3936,9 @@ def test_get_pipeline_job_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_pipeline_job), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_pipeline_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3382,6 +3947,43 @@ def test_get_pipeline_job_non_empty_request_with_auto_populated_field(): ) +def test_get_pipeline_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_pipeline_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_pipeline_job + ] = mock_rpc + request = {} + client.get_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_pipeline_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3413,6 +4015,52 @@ async def test_get_pipeline_job_empty_call_async(): assert args[0] == pipeline_service.GetPipelineJobRequest() +@pytest.mark.asyncio +async def test_get_pipeline_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_pipeline_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_pipeline_job + ] = mock_object + + request = {} + await client.get_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_pipeline_job_async( transport: str = "grpc_asyncio", request_type=pipeline_service.GetPipelineJobRequest @@ -3661,6 +4309,9 @@ def test_list_pipeline_jobs_empty_call(): with mock.patch.object( type(client.transport.list_pipeline_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_pipeline_jobs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3689,6 +4340,9 @@ def test_list_pipeline_jobs_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_pipeline_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_pipeline_jobs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3700,6 +4354,45 @@ def test_list_pipeline_jobs_non_empty_request_with_auto_populated_field(): ) +def test_list_pipeline_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_pipeline_jobs in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_pipeline_jobs + ] = mock_rpc + request = {} + client.list_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_pipeline_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_pipeline_jobs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3725,6 +4418,52 @@ async def test_list_pipeline_jobs_empty_call_async(): assert args[0] == pipeline_service.ListPipelineJobsRequest() +@pytest.mark.asyncio +async def test_list_pipeline_jobs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_pipeline_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_pipeline_jobs + ] = mock_object + + request = {} + await client.list_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_pipeline_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_pipeline_jobs_async( transport: str = "grpc_asyncio", @@ -4163,6 +4902,9 @@ def test_delete_pipeline_job_empty_call(): with mock.patch.object( type(client.transport.delete_pipeline_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_pipeline_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4188,6 +4930,9 @@ def test_delete_pipeline_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_pipeline_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_pipeline_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4196,6 +4941,49 @@ def test_delete_pipeline_job_non_empty_request_with_auto_populated_field(): ) +def test_delete_pipeline_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_pipeline_job in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_pipeline_job + ] = mock_rpc + request = {} + client.delete_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_pipeline_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4219,6 +5007,56 @@ async def test_delete_pipeline_job_empty_call_async(): assert args[0] == pipeline_service.DeletePipelineJobRequest() +@pytest.mark.asyncio +async def test_delete_pipeline_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_pipeline_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_pipeline_job + ] = mock_object + + request = {} + await client.delete_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_pipeline_job_async( transport: str = "grpc_asyncio", @@ -4456,6 +5294,9 @@ def test_batch_delete_pipeline_jobs_empty_call(): with mock.patch.object( type(client.transport.batch_delete_pipeline_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_delete_pipeline_jobs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4481,6 +5322,9 @@ def test_batch_delete_pipeline_jobs_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.batch_delete_pipeline_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.batch_delete_pipeline_jobs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4489,6 +5333,50 @@ def test_batch_delete_pipeline_jobs_non_empty_request_with_auto_populated_field( ) +def test_batch_delete_pipeline_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_delete_pipeline_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_delete_pipeline_jobs + ] = mock_rpc + request = {} + client.batch_delete_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.batch_delete_pipeline_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_batch_delete_pipeline_jobs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4512,6 +5400,56 @@ async def test_batch_delete_pipeline_jobs_empty_call_async(): assert args[0] == pipeline_service.BatchDeletePipelineJobsRequest() +@pytest.mark.asyncio +async def test_batch_delete_pipeline_jobs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_delete_pipeline_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_delete_pipeline_jobs + ] = mock_object + + request = {} + await client.batch_delete_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.batch_delete_pipeline_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_batch_delete_pipeline_jobs_async( transport: str = "grpc_asyncio", @@ -4759,6 +5697,9 @@ def test_cancel_pipeline_job_empty_call(): with mock.patch.object( type(client.transport.cancel_pipeline_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.cancel_pipeline_job() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4784,6 +5725,9 @@ def test_cancel_pipeline_job_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.cancel_pipeline_job), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.cancel_pipeline_job(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4792,6 +5736,45 @@ def test_cancel_pipeline_job_non_empty_request_with_auto_populated_field(): ) +def test_cancel_pipeline_job_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.cancel_pipeline_job in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_pipeline_job + ] = mock_rpc + request = {} + client.cancel_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_pipeline_job_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4813,6 +5796,52 @@ async def test_cancel_pipeline_job_empty_call_async(): assert args[0] == pipeline_service.CancelPipelineJobRequest() +@pytest.mark.asyncio +async def test_cancel_pipeline_job_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.cancel_pipeline_job + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.cancel_pipeline_job + ] = mock_object + + request = {} + await client.cancel_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.cancel_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_cancel_pipeline_job_async( transport: str = "grpc_asyncio", @@ -5044,6 +6073,9 @@ def test_batch_cancel_pipeline_jobs_empty_call(): with mock.patch.object( type(client.transport.batch_cancel_pipeline_jobs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_cancel_pipeline_jobs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5065,16 +6097,63 @@ def test_batch_cancel_pipeline_jobs_non_empty_request_with_auto_populated_field( parent="parent_value", ) - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.batch_cancel_pipeline_jobs), "__call__" - ) as call: - client.batch_cancel_pipeline_jobs(request=request) - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == pipeline_service.BatchCancelPipelineJobsRequest( - parent="parent_value", + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_cancel_pipeline_jobs), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client.batch_cancel_pipeline_jobs(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == pipeline_service.BatchCancelPipelineJobsRequest( + parent="parent_value", + ) + + +def test_batch_cancel_pipeline_jobs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_cancel_pipeline_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. ) + client._transport._wrapped_methods[ + client._transport.batch_cancel_pipeline_jobs + ] = mock_rpc + request = {} + client.batch_cancel_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.batch_cancel_pipeline_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 @pytest.mark.asyncio @@ -5100,6 +6179,56 @@ async def test_batch_cancel_pipeline_jobs_empty_call_async(): assert args[0] == pipeline_service.BatchCancelPipelineJobsRequest() +@pytest.mark.asyncio +async def test_batch_cancel_pipeline_jobs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PipelineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_cancel_pipeline_jobs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_cancel_pipeline_jobs + ] = mock_object + + request = {} + await client.batch_cancel_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.batch_cancel_pipeline_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_batch_cancel_pipeline_jobs_async( transport: str = "grpc_asyncio", @@ -5590,6 +6719,47 @@ def get_message_fields(field): assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED +def test_create_training_pipeline_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_training_pipeline + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_training_pipeline + ] = mock_rpc + + request = {} + client.create_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_training_pipeline_rest_required_fields( request_type=pipeline_service.CreateTrainingPipelineRequest, ): @@ -5880,6 +7050,47 @@ def test_get_training_pipeline_rest(request_type): assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED +def test_get_training_pipeline_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_training_pipeline + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_training_pipeline + ] = mock_rpc + + request = {} + client.get_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_training_pipeline_rest_required_fields( request_type=pipeline_service.GetTrainingPipelineRequest, ): @@ -6151,6 +7362,47 @@ def test_list_training_pipelines_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_training_pipelines_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_training_pipelines + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_training_pipelines + ] = mock_rpc + + request = {} + client.list_training_pipelines(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_training_pipelines(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_training_pipelines_rest_required_fields( request_type=pipeline_service.ListTrainingPipelinesRequest, ): @@ -6495,6 +7747,51 @@ def test_delete_training_pipeline_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_training_pipeline_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_training_pipeline + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_training_pipeline + ] = mock_rpc + + request = {} + client.delete_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_training_pipeline_rest_required_fields( request_type=pipeline_service.DeleteTrainingPipelineRequest, ): @@ -6760,6 +8057,47 @@ def test_cancel_training_pipeline_rest(request_type): assert response is None +def test_cancel_training_pipeline_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.cancel_training_pipeline + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_training_pipeline + ] = mock_rpc + + request = {} + client.cancel_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_training_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_cancel_training_pipeline_rest_required_fields( request_type=pipeline_service.CancelTrainingPipelineRequest, ): @@ -7205,6 +8543,46 @@ def get_message_fields(field): assert response.preflight_validations is True +def test_create_pipeline_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_pipeline_job in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_pipeline_job + ] = mock_rpc + + request = {} + client.create_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_pipeline_job_rest_required_fields( request_type=pipeline_service.CreatePipelineJobRequest, ): @@ -7503,6 +8881,44 @@ def test_get_pipeline_job_rest(request_type): assert response.preflight_validations is True +def test_get_pipeline_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_pipeline_job in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_pipeline_job + ] = mock_rpc + + request = {} + client.get_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_pipeline_job_rest_required_fields( request_type=pipeline_service.GetPipelineJobRequest, ): @@ -7772,6 +9188,46 @@ def test_list_pipeline_jobs_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_pipeline_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_pipeline_jobs in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_pipeline_jobs + ] = mock_rpc + + request = {} + client.list_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_pipeline_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_pipeline_jobs_rest_required_fields( request_type=pipeline_service.ListPipelineJobsRequest, ): @@ -8112,6 +9568,50 @@ def test_delete_pipeline_job_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_pipeline_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_pipeline_job in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_pipeline_job + ] = mock_rpc + + request = {} + client.delete_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_pipeline_job_rest_required_fields( request_type=pipeline_service.DeletePipelineJobRequest, ): @@ -8373,6 +9873,51 @@ def test_batch_delete_pipeline_jobs_rest(request_type): assert response.operation.name == "operations/spam" +def test_batch_delete_pipeline_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_delete_pipeline_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_delete_pipeline_jobs + ] = mock_rpc + + request = {} + client.batch_delete_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.batch_delete_pipeline_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_batch_delete_pipeline_jobs_rest_required_fields( request_type=pipeline_service.BatchDeletePipelineJobsRequest, ): @@ -8648,6 +10193,46 @@ def test_cancel_pipeline_job_rest(request_type): assert response is None +def test_cancel_pipeline_job_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.cancel_pipeline_job in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.cancel_pipeline_job + ] = mock_rpc + + request = {} + client.cancel_pipeline_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.cancel_pipeline_job(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_cancel_pipeline_job_rest_required_fields( request_type=pipeline_service.CancelPipelineJobRequest, ): @@ -8900,6 +10485,51 @@ def test_batch_cancel_pipeline_jobs_rest(request_type): assert response.operation.name == "operations/spam" +def test_batch_cancel_pipeline_jobs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PipelineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_cancel_pipeline_jobs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_cancel_pipeline_jobs + ] = mock_rpc + + request = {} + client.batch_cancel_pipeline_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.batch_cancel_pipeline_jobs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_batch_cancel_pipeline_jobs_rest_required_fields( request_type=pipeline_service.BatchCancelPipelineJobsRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py index 96cde1d559..3e2ee8aee6 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py @@ -1227,6 +1227,9 @@ def test_predict_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.predict), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.predict() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1250,6 +1253,9 @@ def test_predict_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.predict), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.predict(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1258,6 +1264,41 @@ def test_predict_non_empty_request_with_auto_populated_field(): ) +def test_predict_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.predict in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.predict] = mock_rpc + request = {} + client.predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_predict_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1284,6 +1325,50 @@ async def test_predict_empty_call_async(): assert args[0] == prediction_service.PredictRequest() +@pytest.mark.asyncio +async def test_predict_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.predict + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.predict + ] = mock_object + + request = {} + await client.predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_predict_async( transport: str = "grpc_asyncio", request_type=prediction_service.PredictRequest @@ -1471,6 +1556,9 @@ def test_raw_predict_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.raw_predict), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.raw_predict() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1494,6 +1582,9 @@ def test_raw_predict_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.raw_predict), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.raw_predict(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1502,6 +1593,41 @@ def test_raw_predict_non_empty_request_with_auto_populated_field(): ) +def test_raw_predict_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.raw_predict in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.raw_predict] = mock_rpc + request = {} + client.raw_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_raw_predict_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1526,6 +1652,52 @@ async def test_raw_predict_empty_call_async(): assert args[0] == prediction_service.RawPredictRequest() +@pytest.mark.asyncio +async def test_raw_predict_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.raw_predict + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.raw_predict + ] = mock_object + + request = {} + await client.raw_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_raw_predict_async( transport: str = "grpc_asyncio", request_type=prediction_service.RawPredictRequest @@ -1763,6 +1935,9 @@ def test_direct_predict_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.direct_predict), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.direct_predict() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1786,6 +1961,9 @@ def test_direct_predict_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.direct_predict), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.direct_predict(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1794,6 +1972,41 @@ def test_direct_predict_non_empty_request_with_auto_populated_field(): ) +def test_direct_predict_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.direct_predict in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.direct_predict] = mock_rpc + request = {} + client.direct_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.direct_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_direct_predict_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1815,6 +2028,52 @@ async def test_direct_predict_empty_call_async(): assert args[0] == prediction_service.DirectPredictRequest() +@pytest.mark.asyncio +async def test_direct_predict_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.direct_predict + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.direct_predict + ] = mock_object + + request = {} + await client.direct_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.direct_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_direct_predict_async( transport: str = "grpc_asyncio", @@ -1963,6 +2222,9 @@ def test_direct_raw_predict_empty_call(): with mock.patch.object( type(client.transport.direct_raw_predict), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.direct_raw_predict() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1989,6 +2251,9 @@ def test_direct_raw_predict_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.direct_raw_predict), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.direct_raw_predict(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1998,6 +2263,45 @@ def test_direct_raw_predict_non_empty_request_with_auto_populated_field(): ) +def test_direct_raw_predict_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.direct_raw_predict in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.direct_raw_predict + ] = mock_rpc + request = {} + client.direct_raw_predict(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.direct_raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_direct_raw_predict_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2023,6 +2327,52 @@ async def test_direct_raw_predict_empty_call_async(): assert args[0] == prediction_service.DirectRawPredictRequest() +@pytest.mark.asyncio +async def test_direct_raw_predict_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.direct_raw_predict + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.direct_raw_predict + ] = mock_object + + request = {} + await client.direct_raw_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.direct_raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_direct_raw_predict_async( transport: str = "grpc_asyncio", @@ -2166,6 +2516,92 @@ def test_stream_direct_predict(request_type, transport: str = "grpc"): assert isinstance(message, prediction_service.StreamDirectPredictResponse) +def test_stream_direct_predict_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.stream_direct_predict + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.stream_direct_predict + ] = mock_rpc + request = [{}] + client.stream_direct_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.stream_direct_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_stream_direct_predict_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.stream_direct_predict + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.stream_direct_predict + ] = mock_object + + request = [{}] + await client.stream_direct_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.stream_direct_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_stream_direct_predict_async( transport: str = "grpc_asyncio", @@ -2243,6 +2679,92 @@ def test_stream_direct_raw_predict(request_type, transport: str = "grpc"): assert isinstance(message, prediction_service.StreamDirectRawPredictResponse) +def test_stream_direct_raw_predict_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.stream_direct_raw_predict + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.stream_direct_raw_predict + ] = mock_rpc + request = [{}] + client.stream_direct_raw_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.stream_direct_raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_stream_direct_raw_predict_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.stream_direct_raw_predict + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.stream_direct_raw_predict + ] = mock_object + + request = [{}] + await client.stream_direct_raw_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.stream_direct_raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_stream_direct_raw_predict_async( transport: str = "grpc_asyncio", @@ -2320,6 +2842,89 @@ def test_streaming_predict(request_type, transport: str = "grpc"): assert isinstance(message, prediction_service.StreamingPredictResponse) +def test_streaming_predict_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.streaming_predict in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.streaming_predict + ] = mock_rpc + request = [{}] + client.streaming_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.streaming_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_streaming_predict_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.streaming_predict + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.streaming_predict + ] = mock_object + + request = [{}] + await client.streaming_predict(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.streaming_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_streaming_predict_async( transport: str = "grpc_asyncio", @@ -2409,6 +3014,9 @@ def test_server_streaming_predict_empty_call(): with mock.patch.object( type(client.transport.server_streaming_predict), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.server_streaming_predict() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2434,6 +3042,9 @@ def test_server_streaming_predict_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.server_streaming_predict), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.server_streaming_predict(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2442,8 +3053,48 @@ def test_server_streaming_predict_non_empty_request_with_auto_populated_field(): ) -@pytest.mark.asyncio -async def test_server_streaming_predict_empty_call_async(): +def test_server_streaming_predict_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.server_streaming_predict + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.server_streaming_predict + ] = mock_rpc + request = {} + client.server_streaming_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.server_streaming_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_server_streaming_predict_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = PredictionServiceAsyncClient( @@ -2466,6 +3117,52 @@ async def test_server_streaming_predict_empty_call_async(): assert args[0] == prediction_service.StreamingPredictRequest() +@pytest.mark.asyncio +async def test_server_streaming_predict_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.server_streaming_predict + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.server_streaming_predict + ] = mock_object + + request = {} + await client.server_streaming_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.server_streaming_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_server_streaming_predict_async( transport: str = "grpc_asyncio", @@ -2609,6 +3306,92 @@ def test_streaming_raw_predict(request_type, transport: str = "grpc"): assert isinstance(message, prediction_service.StreamingRawPredictResponse) +def test_streaming_raw_predict_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.streaming_raw_predict + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.streaming_raw_predict + ] = mock_rpc + request = [{}] + client.streaming_raw_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.streaming_raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_streaming_raw_predict_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.streaming_raw_predict + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.streaming_raw_predict + ] = mock_object + + request = [{}] + await client.streaming_raw_predict(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.streaming_raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_streaming_raw_predict_async( transport: str = "grpc_asyncio", @@ -2696,6 +3479,9 @@ def test_explain_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.explain), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.explain() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2720,6 +3506,9 @@ def test_explain_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.explain), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.explain(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2729,6 +3518,41 @@ def test_explain_non_empty_request_with_auto_populated_field(): ) +def test_explain_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.explain in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.explain] = mock_rpc + request = {} + client.explain(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.explain(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_explain_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2752,6 +3576,50 @@ async def test_explain_empty_call_async(): assert args[0] == prediction_service.ExplainRequest() +@pytest.mark.asyncio +async def test_explain_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.explain + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.explain + ] = mock_object + + request = {} + await client.explain(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.explain(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_explain_async( transport: str = "grpc_asyncio", request_type=prediction_service.ExplainRequest @@ -2935,6 +3803,9 @@ def test_count_tokens_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.count_tokens), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.count_tokens() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2959,6 +3830,9 @@ def test_count_tokens_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.count_tokens), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.count_tokens(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2968,6 +3842,41 @@ def test_count_tokens_non_empty_request_with_auto_populated_field(): ) +def test_count_tokens_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.count_tokens in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.count_tokens] = mock_rpc + request = {} + client.count_tokens(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.count_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_count_tokens_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2992,6 +3901,52 @@ async def test_count_tokens_empty_call_async(): assert args[0] == prediction_service.CountTokensRequest() +@pytest.mark.asyncio +async def test_count_tokens_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.count_tokens + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.count_tokens + ] = mock_object + + request = {} + await client.count_tokens(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.count_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_count_tokens_async( transport: str = "grpc_asyncio", request_type=prediction_service.CountTokensRequest @@ -3229,6 +4184,9 @@ def test_generate_content_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.generate_content), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.generate_content() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3252,6 +4210,9 @@ def test_generate_content_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.generate_content), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.generate_content(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3260,6 +4221,43 @@ def test_generate_content_non_empty_request_with_auto_populated_field(): ) +def test_generate_content_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.generate_content in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.generate_content + ] = mock_rpc + request = {} + client.generate_content(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.generate_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_generate_content_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3281,6 +4279,52 @@ async def test_generate_content_empty_call_async(): assert args[0] == prediction_service.GenerateContentRequest() +@pytest.mark.asyncio +async def test_generate_content_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.generate_content + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.generate_content + ] = mock_object + + request = {} + await client.generate_content(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.generate_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_generate_content_async( transport: str = "grpc_asyncio", @@ -3519,6 +4563,9 @@ def test_stream_generate_content_empty_call(): with mock.patch.object( type(client.transport.stream_generate_content), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.stream_generate_content() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3544,6 +4591,9 @@ def test_stream_generate_content_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.stream_generate_content), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.stream_generate_content(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3552,6 +4602,46 @@ def test_stream_generate_content_non_empty_request_with_auto_populated_field(): ) +def test_stream_generate_content_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.stream_generate_content + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.stream_generate_content + ] = mock_rpc + request = {} + client.stream_generate_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.stream_generate_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_stream_generate_content_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3576,6 +4666,52 @@ async def test_stream_generate_content_empty_call_async(): assert args[0] == prediction_service.GenerateContentRequest() +@pytest.mark.asyncio +async def test_stream_generate_content_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.stream_generate_content + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.stream_generate_content + ] = mock_object + + request = {} + await client.stream_generate_content(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.stream_generate_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_stream_generate_content_async( transport: str = "grpc_asyncio", @@ -3780,311 +4916,29 @@ async def test_stream_generate_content_flattened_error_async(): @pytest.mark.parametrize( "request_type", [ - prediction_service.ChatCompletionsRequest, + prediction_service.PredictRequest, dict, ], ) -def test_chat_completions(request_type, transport: str = "grpc"): +def test_predict_rest(request_type): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport=transport, + transport="rest", ) - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.chat_completions), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = iter([httpbody_pb2.HttpBody()]) - response = client.chat_completions(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - request = prediction_service.ChatCompletionsRequest() - assert args[0] == request - - # Establish that the response is the type that we expect. - for message in response: - assert isinstance(message, httpbody_pb2.HttpBody) + # send a request that will satisfy transcoding + request_init = {"endpoint": "projects/sample1/locations/sample2/endpoints/sample3"} + request = request_type(**request_init) - -def test_chat_completions_empty_call(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. 
- client = PredictionServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc", - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.chat_completions), "__call__") as call: - client.chat_completions() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == prediction_service.ChatCompletionsRequest() - - -def test_chat_completions_non_empty_request_with_auto_populated_field(): - # This test is a coverage failsafe to make sure that UUID4 fields are - # automatically populated, according to AIP-4235, with non-empty requests. - client = PredictionServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc", - ) - - # Populate all string fields in the request which are not UUID4 - # since we want to check that UUID4 are populated automatically - # if they meet the requirements of AIP 4235. - request = prediction_service.ChatCompletionsRequest( - endpoint="endpoint_value", - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.chat_completions), "__call__") as call: - client.chat_completions(request=request) - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == prediction_service.ChatCompletionsRequest( - endpoint="endpoint_value", - ) - - -@pytest.mark.asyncio -async def test_chat_completions_empty_call_async(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. - client = PredictionServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc_asyncio", - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.chat_completions), "__call__") as call: - # Designate an appropriate return value for the call. 
- call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) - call.return_value.read = mock.AsyncMock(side_effect=[httpbody_pb2.HttpBody()]) - response = await client.chat_completions() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == prediction_service.ChatCompletionsRequest() - - -@pytest.mark.asyncio -async def test_chat_completions_async( - transport: str = "grpc_asyncio", - request_type=prediction_service.ChatCompletionsRequest, -): - client = PredictionServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.chat_completions), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) - call.return_value.read = mock.AsyncMock(side_effect=[httpbody_pb2.HttpBody()]) - response = await client.chat_completions(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - request = prediction_service.ChatCompletionsRequest() - assert args[0] == request - - # Establish that the response is the type that we expect. - message = await response.read() - assert isinstance(message, httpbody_pb2.HttpBody) - - -@pytest.mark.asyncio -async def test_chat_completions_async_from_dict(): - await test_chat_completions_async(request_type=dict) - - -def test_chat_completions_field_headers(): - client = PredictionServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. 
- request = prediction_service.ChatCompletionsRequest() - - request.endpoint = "endpoint_value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.chat_completions), "__call__") as call: - call.return_value = iter([httpbody_pb2.HttpBody()]) - client.chat_completions(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ( - "x-goog-request-params", - "endpoint=endpoint_value", - ) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_chat_completions_field_headers_async(): - client = PredictionServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = prediction_service.ChatCompletionsRequest() - - request.endpoint = "endpoint_value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.chat_completions), "__call__") as call: - call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) - call.return_value.read = mock.AsyncMock(side_effect=[httpbody_pb2.HttpBody()]) - await client.chat_completions(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ( - "x-goog-request-params", - "endpoint=endpoint_value", - ) in kw["metadata"] - - -def test_chat_completions_flattened(): - client = PredictionServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.chat_completions), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = iter([httpbody_pb2.HttpBody()]) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.chat_completions( - endpoint="endpoint_value", - http_body=httpbody_pb2.HttpBody(content_type="content_type_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - arg = args[0].endpoint - mock_val = "endpoint_value" - assert arg == mock_val - arg = args[0].http_body - mock_val = httpbody_pb2.HttpBody(content_type="content_type_value") - assert arg == mock_val - - -def test_chat_completions_flattened_error(): - client = PredictionServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.chat_completions( - prediction_service.ChatCompletionsRequest(), - endpoint="endpoint_value", - http_body=httpbody_pb2.HttpBody(content_type="content_type_value"), - ) - - -@pytest.mark.asyncio -async def test_chat_completions_flattened_async(): - client = PredictionServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.chat_completions), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = iter([httpbody_pb2.HttpBody()]) - - call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. 
- response = await client.chat_completions( - endpoint="endpoint_value", - http_body=httpbody_pb2.HttpBody(content_type="content_type_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - arg = args[0].endpoint - mock_val = "endpoint_value" - assert arg == mock_val - arg = args[0].http_body - mock_val = httpbody_pb2.HttpBody(content_type="content_type_value") - assert arg == mock_val - - -@pytest.mark.asyncio -async def test_chat_completions_flattened_error_async(): - client = PredictionServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.chat_completions( - prediction_service.ChatCompletionsRequest(), - endpoint="endpoint_value", - http_body=httpbody_pb2.HttpBody(content_type="content_type_value"), - ) - - -@pytest.mark.parametrize( - "request_type", - [ - prediction_service.PredictRequest, - dict, - ], -) -def test_predict_rest(request_type): - client = PredictionServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="rest", - ) - - # send a request that will satisfy transcoding - request_init = {"endpoint": "projects/sample1/locations/sample2/endpoints/sample3"} - request = request_type(**request_init) - - # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), "request") as req: - # Designate an appropriate value for the returned response. - return_value = prediction_service.PredictResponse( - deployed_model_id="deployed_model_id_value", - model="model_value", - model_version_id="model_version_id_value", - model_display_name="model_display_name_value", - ) + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = prediction_service.PredictResponse( + deployed_model_id="deployed_model_id_value", + model="model_value", + model_version_id="model_version_id_value", + model_display_name="model_display_name_value", + ) # Wrap the value into a proper Response obj response_value = Response() @@ -4105,6 +4959,42 @@ def test_predict_rest(request_type): assert response.model_display_name == "model_display_name_value" +def test_predict_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.predict in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.predict] = mock_rpc + + request = {} + client.predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_predict_rest_required_fields(request_type=prediction_service.PredictRequest): transport_class = transports.PredictionServiceRestTransport @@ -4385,6 +5275,42 @@ def test_raw_predict_rest(request_type): assert response.data == b"data_blob" +def test_raw_predict_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.raw_predict in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.raw_predict] = mock_rpc + + request = {} + client.raw_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_raw_predict_rest_required_fields( request_type=prediction_service.RawPredictRequest, ): @@ -4648,6 +5574,42 @@ def test_direct_predict_rest(request_type): assert isinstance(response, prediction_service.DirectPredictResponse) +def test_direct_predict_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.direct_predict in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.direct_predict] = mock_rpc + + request = {} + client.direct_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.direct_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_direct_predict_rest_required_fields( request_type=prediction_service.DirectPredictRequest, ): @@ -4859,6 +5821,46 @@ def test_direct_raw_predict_rest(request_type): assert response.output == b"output_blob" +def test_direct_raw_predict_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.direct_raw_predict in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.direct_raw_predict + ] = mock_rpc + + request = {} + client.direct_raw_predict(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.direct_raw_predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_direct_raw_predict_rest_required_fields( request_type=prediction_service.DirectRawPredictRequest, ): @@ -5107,233 +6109,50 @@ def test_server_streaming_predict_rest(request_type): assert isinstance(response, prediction_service.StreamingPredictResponse) -def test_server_streaming_predict_rest_required_fields( - request_type=prediction_service.StreamingPredictRequest, -): - transport_class = transports.PredictionServiceRestTransport - - request_init = {} - request_init["endpoint"] = "" - request = request_type(**request_init) - pb_request = request_type.pb(request) - jsonified_request = json.loads( - json_format.MessageToJson(pb_request, use_integers_for_enums=False) - ) - - # verify fields with default values are dropped - - unset_fields = transport_class( - credentials=ga_credentials.AnonymousCredentials() - ).server_streaming_predict._get_unset_required_fields(jsonified_request) - jsonified_request.update(unset_fields) - - # verify required fields with default values are now present - - jsonified_request["endpoint"] = "endpoint_value" - - unset_fields = transport_class( - credentials=ga_credentials.AnonymousCredentials() - ).server_streaming_predict._get_unset_required_fields(jsonified_request) - jsonified_request.update(unset_fields) - - # verify required fields with non-default values are left alone - assert "endpoint" in jsonified_request - assert jsonified_request["endpoint"] == "endpoint_value" - - client = PredictionServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="rest", - ) - request = request_type(**request_init) - - # Designate an appropriate value for the returned response. - return_value = prediction_service.StreamingPredictResponse() - # Mock the http request call within the method and fake a response. 
- with mock.patch.object(Session, "request") as req: - # We need to mock transcode() because providing default values - # for required fields will fail the real version if the http_options - # expect actual values for those fields. - with mock.patch.object(path_template, "transcode") as transcode: - # A uri without fields and an empty body will force all the - # request fields to show up in the query_params. - pb_request = request_type.pb(request) - transcode_result = { - "uri": "v1/sample_method", - "method": "post", - "query_params": pb_request, - } - transcode_result["body"] = pb_request - transcode.return_value = transcode_result - - response_value = Response() - response_value.status_code = 200 - - # Convert return value to protobuf type - return_value = prediction_service.StreamingPredictResponse.pb(return_value) - json_return_value = json_format.MessageToJson(return_value) - json_return_value = "[{}]".format(json_return_value) - - response_value._content = json_return_value.encode("UTF-8") - req.return_value = response_value - - with mock.patch.object(response_value, "iter_content") as iter_content: - iter_content.return_value = iter(json_return_value) - response = client.server_streaming_predict(request) - - expected_params = [] - actual_params = req.call_args.kwargs["params"] - assert expected_params == actual_params - - -def test_server_streaming_predict_rest_unset_required_fields(): - transport = transports.PredictionServiceRestTransport( - credentials=ga_credentials.AnonymousCredentials - ) - - unset_fields = transport.server_streaming_predict._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("endpoint",))) - - -@pytest.mark.parametrize("null_interceptor", [True, False]) -def test_server_streaming_predict_rest_interceptors(null_interceptor): - transport = transports.PredictionServiceRestTransport( - credentials=ga_credentials.AnonymousCredentials(), - interceptor=None - if null_interceptor - else 
transports.PredictionServiceRestInterceptor(), - ) - client = PredictionServiceClient(transport=transport) - with mock.patch.object( - type(client.transport._session), "request" - ) as req, mock.patch.object( - path_template, "transcode" - ) as transcode, mock.patch.object( - transports.PredictionServiceRestInterceptor, "post_server_streaming_predict" - ) as post, mock.patch.object( - transports.PredictionServiceRestInterceptor, "pre_server_streaming_predict" - ) as pre: - pre.assert_not_called() - post.assert_not_called() - pb_message = prediction_service.StreamingPredictRequest.pb( - prediction_service.StreamingPredictRequest() - ) - transcode.return_value = { - "method": "post", - "uri": "my_uri", - "body": pb_message, - "query_params": pb_message, - } - - req.return_value = Response() - req.return_value.status_code = 200 - req.return_value.request = PreparedRequest() - req.return_value._content = prediction_service.StreamingPredictResponse.to_json( - prediction_service.StreamingPredictResponse() +def test_server_streaming_predict_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", ) - req.return_value._content = "[{}]".format(req.return_value._content) - request = prediction_service.StreamingPredictRequest() - metadata = [ - ("key", "val"), - ("cephalopod", "squid"), - ] - pre.return_value = request, metadata - post.return_value = prediction_service.StreamingPredictResponse() + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() - client.server_streaming_predict( - request, - metadata=[ - ("key", "val"), - ("cephalopod", "squid"), - ], + # Ensure method has been cached + assert ( + client._transport.server_streaming_predict 
+ in client._transport._wrapped_methods ) - pre.assert_called_once() - post.assert_called_once() - - -def test_server_streaming_predict_rest_bad_request( - transport: str = "rest", request_type=prediction_service.StreamingPredictRequest -): - client = PredictionServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # send a request that will satisfy transcoding - request_init = {"endpoint": "projects/sample1/locations/sample2/endpoints/sample3"} - request = request_type(**request_init) + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.server_streaming_predict + ] = mock_rpc - # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, "request") as req, pytest.raises( - core_exceptions.BadRequest - ): - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 400 - response_value.request = Request() - req.return_value = response_value + request = {} client.server_streaming_predict(request) + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 -def test_server_streaming_predict_rest_error(): - client = PredictionServiceClient( - credentials=ga_credentials.AnonymousCredentials(), transport="rest" - ) - - -def test_streaming_raw_predict_rest_no_http_options(): - client = PredictionServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="rest", - ) - request = prediction_service.StreamingRawPredictRequest() - requests = [request] - with pytest.raises(RuntimeError): - client.streaming_raw_predict(requests) - - -@pytest.mark.parametrize( - "request_type", - [ - prediction_service.ExplainRequest, - dict, - ], -) -def test_explain_rest(request_type): - client = PredictionServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="rest", - ) - - # send a request that will satisfy transcoding - request_init = {"endpoint": "projects/sample1/locations/sample2/endpoints/sample3"} - request = request_type(**request_init) - - # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), "request") as req: - # Designate an appropriate value for the returned response. - return_value = prediction_service.ExplainResponse( - deployed_model_id="deployed_model_id_value", - ) - - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 200 - # Convert return value to protobuf type - return_value = prediction_service.ExplainResponse.pb(return_value) - json_return_value = json_format.MessageToJson(return_value) - - response_value._content = json_return_value.encode("UTF-8") - req.return_value = response_value - response = client.explain(request) + client.server_streaming_predict(request) - # Establish that the response is the type that we expect. 
- assert isinstance(response, prediction_service.ExplainResponse) - assert response.deployed_model_id == "deployed_model_id_value" + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 -def test_explain_rest_required_fields(request_type=prediction_service.ExplainRequest): +def test_server_streaming_predict_rest_required_fields( + request_type=prediction_service.StreamingPredictRequest, +): transport_class = transports.PredictionServiceRestTransport request_init = {} @@ -5348,7 +6167,7 @@ def test_explain_rest_required_fields(request_type=prediction_service.ExplainReq unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).explain._get_unset_required_fields(jsonified_request) + ).server_streaming_predict._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present @@ -5357,7 +6176,7 @@ def test_explain_rest_required_fields(request_type=prediction_service.ExplainReq unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).explain._get_unset_required_fields(jsonified_request) + ).server_streaming_predict._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone @@ -5371,7 +6190,7 @@ def test_explain_rest_required_fields(request_type=prediction_service.ExplainReq request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = prediction_service.ExplainResponse() + return_value = prediction_service.StreamingPredictResponse() # Mock the http request call within the method and fake a response. 
with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -5393,38 +6212,33 @@ def test_explain_rest_required_fields(request_type=prediction_service.ExplainReq response_value.status_code = 200 # Convert return value to protobuf type - return_value = prediction_service.ExplainResponse.pb(return_value) + return_value = prediction_service.StreamingPredictResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) + json_return_value = "[{}]".format(json_return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.explain(request) + with mock.patch.object(response_value, "iter_content") as iter_content: + iter_content.return_value = iter(json_return_value) + response = client.server_streaming_predict(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_explain_rest_unset_required_fields(): +def test_server_streaming_predict_rest_unset_required_fields(): transport = transports.PredictionServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.explain._get_unset_required_fields({}) - assert set(unset_fields) == ( - set(()) - & set( - ( - "endpoint", - "instances", - ) - ) - ) + unset_fields = transport.server_streaming_predict._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("endpoint",))) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_explain_rest_interceptors(null_interceptor): +def test_server_streaming_predict_rest_interceptors(null_interceptor): transport = transports.PredictionServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -5437,14 +6251,14 @@ def test_explain_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - 
transports.PredictionServiceRestInterceptor, "post_explain" + transports.PredictionServiceRestInterceptor, "post_server_streaming_predict" ) as post, mock.patch.object( - transports.PredictionServiceRestInterceptor, "pre_explain" + transports.PredictionServiceRestInterceptor, "pre_server_streaming_predict" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = prediction_service.ExplainRequest.pb( - prediction_service.ExplainRequest() + pb_message = prediction_service.StreamingPredictRequest.pb( + prediction_service.StreamingPredictRequest() ) transcode.return_value = { "method": "post", @@ -5456,19 +6270,20 @@ def test_explain_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = prediction_service.ExplainResponse.to_json( - prediction_service.ExplainResponse() + req.return_value._content = prediction_service.StreamingPredictResponse.to_json( + prediction_service.StreamingPredictResponse() ) + req.return_value._content = "[{}]".format(req.return_value._content) - request = prediction_service.ExplainRequest() + request = prediction_service.StreamingPredictRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = prediction_service.ExplainResponse() + post.return_value = prediction_service.StreamingPredictResponse() - client.explain( + client.server_streaming_predict( request, metadata=[ ("key", "val"), @@ -5480,8 +6295,8 @@ def test_explain_rest_interceptors(null_interceptor): post.assert_called_once() -def test_explain_rest_bad_request( - transport: str = "rest", request_type=prediction_service.ExplainRequest +def test_server_streaming_predict_rest_bad_request( + transport: str = "rest", request_type=prediction_service.StreamingPredictRequest ): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -5501,88 +6316,34 @@ def 
test_explain_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.explain(request) + client.server_streaming_predict(request) -def test_explain_rest_flattened(): +def test_server_streaming_predict_rest_error(): client = PredictionServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="rest", + credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) - # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), "request") as req: - # Designate an appropriate value for the returned response. - return_value = prediction_service.ExplainResponse() - - # get arguments that satisfy an http rule for this method - sample_request = { - "endpoint": "projects/sample1/locations/sample2/endpoints/sample3" - } - - # get truthy value for each flattened field - mock_args = dict( - endpoint="endpoint_value", - instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)], - parameters=struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE), - deployed_model_id="deployed_model_id_value", - ) - mock_args.update(sample_request) - - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 200 - # Convert return value to protobuf type - return_value = prediction_service.ExplainResponse.pb(return_value) - json_return_value = json_format.MessageToJson(return_value) - response_value._content = json_return_value.encode("UTF-8") - req.return_value = response_value - - client.explain(**mock_args) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(req.mock_calls) == 1 - _, args, _ = req.mock_calls[0] - assert path_template.validate( - "%s/v1beta1/{endpoint=projects/*/locations/*/endpoints/*}:explain" - % client.transport._host, - args[1], - ) - -def test_explain_rest_flattened_error(transport: str = "rest"): +def test_streaming_raw_predict_rest_no_http_options(): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.explain( - prediction_service.ExplainRequest(), - endpoint="endpoint_value", - instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)], - parameters=struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE), - deployed_model_id="deployed_model_id_value", - ) - - -def test_explain_rest_error(): - client = PredictionServiceClient( - credentials=ga_credentials.AnonymousCredentials(), transport="rest" + transport="rest", ) + request = prediction_service.StreamingRawPredictRequest() + requests = [request] + with pytest.raises(RuntimeError): + client.streaming_raw_predict(requests) @pytest.mark.parametrize( "request_type", [ - prediction_service.CountTokensRequest, + prediction_service.ExplainRequest, dict, ], ) -def test_count_tokens_rest(request_type): +def test_explain_rest(request_type): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -5595,36 +6356,67 @@ def test_count_tokens_rest(request_type): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
- return_value = prediction_service.CountTokensResponse( - total_tokens=1303, - total_billable_characters=2617, + return_value = prediction_service.ExplainResponse( + deployed_model_id="deployed_model_id_value", ) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 # Convert return value to protobuf type - return_value = prediction_service.CountTokensResponse.pb(return_value) + return_value = prediction_service.ExplainResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.count_tokens(request) + response = client.explain(request) # Establish that the response is the type that we expect. - assert isinstance(response, prediction_service.CountTokensResponse) - assert response.total_tokens == 1303 - assert response.total_billable_characters == 2617 + assert isinstance(response, prediction_service.ExplainResponse) + assert response.deployed_model_id == "deployed_model_id_value" -def test_count_tokens_rest_required_fields( - request_type=prediction_service.CountTokensRequest, -): +def test_explain_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.explain in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[client._transport.explain] = mock_rpc + + request = {} + client.explain(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.explain(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_explain_rest_required_fields(request_type=prediction_service.ExplainRequest): transport_class = transports.PredictionServiceRestTransport request_init = {} request_init["endpoint"] = "" - request_init["model"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -5635,24 +6427,21 @@ def test_count_tokens_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).count_tokens._get_unset_required_fields(jsonified_request) + ).explain._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present jsonified_request["endpoint"] = "endpoint_value" - jsonified_request["model"] = "model_value" unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).count_tokens._get_unset_required_fields(jsonified_request) + ).explain._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone assert "endpoint" in jsonified_request assert jsonified_request["endpoint"] == "endpoint_value" - assert "model" in jsonified_request - assert jsonified_request["model"] == "model_value" client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -5661,7 +6450,7 @@ def test_count_tokens_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. 
- return_value = prediction_service.CountTokensResponse() + return_value = prediction_service.ExplainResponse() # Mock the http request call within the method and fake a response. with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -5683,40 +6472,38 @@ def test_count_tokens_rest_required_fields( response_value.status_code = 200 # Convert return value to protobuf type - return_value = prediction_service.CountTokensResponse.pb(return_value) + return_value = prediction_service.ExplainResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.count_tokens(request) + response = client.explain(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_count_tokens_rest_unset_required_fields(): +def test_explain_rest_unset_required_fields(): transport = transports.PredictionServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.count_tokens._get_unset_required_fields({}) + unset_fields = transport.explain._get_unset_required_fields({}) assert set(unset_fields) == ( set(()) & set( ( "endpoint", - "model", "instances", - "contents", ) ) ) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_count_tokens_rest_interceptors(null_interceptor): +def test_explain_rest_interceptors(null_interceptor): transport = transports.PredictionServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -5729,14 +6516,14 @@ def test_count_tokens_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - transports.PredictionServiceRestInterceptor, "post_count_tokens" + transports.PredictionServiceRestInterceptor, "post_explain" ) as post, mock.patch.object( - 
transports.PredictionServiceRestInterceptor, "pre_count_tokens" + transports.PredictionServiceRestInterceptor, "pre_explain" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = prediction_service.CountTokensRequest.pb( - prediction_service.CountTokensRequest() + pb_message = prediction_service.ExplainRequest.pb( + prediction_service.ExplainRequest() ) transcode.return_value = { "method": "post", @@ -5748,19 +6535,19 @@ def test_count_tokens_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = prediction_service.CountTokensResponse.to_json( - prediction_service.CountTokensResponse() + req.return_value._content = prediction_service.ExplainResponse.to_json( + prediction_service.ExplainResponse() ) - request = prediction_service.CountTokensRequest() + request = prediction_service.ExplainRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = prediction_service.CountTokensResponse() + post.return_value = prediction_service.ExplainResponse() - client.count_tokens( + client.explain( request, metadata=[ ("key", "val"), @@ -5772,8 +6559,8 @@ def test_count_tokens_rest_interceptors(null_interceptor): post.assert_called_once() -def test_count_tokens_rest_bad_request( - transport: str = "rest", request_type=prediction_service.CountTokensRequest +def test_explain_rest_bad_request( + transport: str = "rest", request_type=prediction_service.ExplainRequest ): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -5793,10 +6580,10 @@ def test_count_tokens_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.count_tokens(request) + client.explain(request) -def test_count_tokens_rest_flattened(): +def test_explain_rest_flattened(): client = PredictionServiceClient( 
credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -5805,7 +6592,7 @@ def test_count_tokens_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = prediction_service.CountTokensResponse() + return_value = prediction_service.ExplainResponse() # get arguments that satisfy an http rule for this method sample_request = { @@ -5816,6 +6603,8 @@ def test_count_tokens_rest_flattened(): mock_args = dict( endpoint="endpoint_value", instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)], + parameters=struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE), + deployed_model_id="deployed_model_id_value", ) mock_args.update(sample_request) @@ -5823,25 +6612,25 @@ def test_count_tokens_rest_flattened(): response_value = Response() response_value.status_code = 200 # Convert return value to protobuf type - return_value = prediction_service.CountTokensResponse.pb(return_value) + return_value = prediction_service.ExplainResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - client.count_tokens(**mock_args) + client.explain(**mock_args) # Establish that the underlying call was made with the expected # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{endpoint=projects/*/locations/*/endpoints/*}:countTokens" + "%s/v1beta1/{endpoint=projects/*/locations/*/endpoints/*}:explain" % client.transport._host, args[1], ) -def test_count_tokens_rest_flattened_error(transport: str = "rest"): +def test_explain_rest_flattened_error(transport: str = "rest"): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -5850,14 +6639,16 @@ def test_count_tokens_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.count_tokens( - prediction_service.CountTokensRequest(), + client.explain( + prediction_service.ExplainRequest(), endpoint="endpoint_value", instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)], + parameters=struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE), + deployed_model_id="deployed_model_id_value", ) -def test_count_tokens_rest_error(): +def test_explain_rest_error(): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -5866,46 +6657,88 @@ def test_count_tokens_rest_error(): @pytest.mark.parametrize( "request_type", [ - prediction_service.GenerateContentRequest, + prediction_service.CountTokensRequest, dict, ], ) -def test_generate_content_rest(request_type): +def test_count_tokens_rest(request_type): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", ) # send a request that will satisfy transcoding - request_init = {"model": "projects/sample1/locations/sample2/endpoints/sample3"} + request_init = {"endpoint": "projects/sample1/locations/sample2/endpoints/sample3"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. 
with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = prediction_service.GenerateContentResponse() + return_value = prediction_service.CountTokensResponse( + total_tokens=1303, + total_billable_characters=2617, + ) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 # Convert return value to protobuf type - return_value = prediction_service.GenerateContentResponse.pb(return_value) + return_value = prediction_service.CountTokensResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.generate_content(request) + response = client.count_tokens(request) # Establish that the response is the type that we expect. - assert isinstance(response, prediction_service.GenerateContentResponse) + assert isinstance(response, prediction_service.CountTokensResponse) + assert response.total_tokens == 1303 + assert response.total_billable_characters == 2617 -def test_generate_content_rest_required_fields( - request_type=prediction_service.GenerateContentRequest, +def test_count_tokens_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.count_tokens in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute 
client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.count_tokens] = mock_rpc + + request = {} + client.count_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.count_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_count_tokens_rest_required_fields( + request_type=prediction_service.CountTokensRequest, ): transport_class = transports.PredictionServiceRestTransport request_init = {} + request_init["endpoint"] = "" request_init["model"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) @@ -5917,19 +6750,22 @@ def test_generate_content_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).generate_content._get_unset_required_fields(jsonified_request) + ).count_tokens._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present + jsonified_request["endpoint"] = "endpoint_value" jsonified_request["model"] = "model_value" unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).generate_content._get_unset_required_fields(jsonified_request) + ).count_tokens._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone + assert "endpoint" in jsonified_request + assert jsonified_request["endpoint"] == "endpoint_value" assert "model" in jsonified_request assert jsonified_request["model"] == "model_value" @@ -5940,7 +6776,7 @@ def test_generate_content_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. 
- return_value = prediction_service.GenerateContentResponse() + return_value = prediction_service.CountTokensResponse() # Mock the http request call within the method and fake a response. with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -5962,30 +6798,32 @@ def test_generate_content_rest_required_fields( response_value.status_code = 200 # Convert return value to protobuf type - return_value = prediction_service.GenerateContentResponse.pb(return_value) + return_value = prediction_service.CountTokensResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - response = client.generate_content(request) + response = client.count_tokens(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_generate_content_rest_unset_required_fields(): +def test_count_tokens_rest_unset_required_fields(): transport = transports.PredictionServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.generate_content._get_unset_required_fields({}) + unset_fields = transport.count_tokens._get_unset_required_fields({}) assert set(unset_fields) == ( set(()) & set( ( + "endpoint", "model", + "instances", "contents", ) ) @@ -5993,7 +6831,7 @@ def test_generate_content_rest_unset_required_fields(): @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_generate_content_rest_interceptors(null_interceptor): +def test_count_tokens_rest_interceptors(null_interceptor): transport = transports.PredictionServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -6006,14 +6844,14 @@ def test_generate_content_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - 
transports.PredictionServiceRestInterceptor, "post_generate_content" + transports.PredictionServiceRestInterceptor, "post_count_tokens" ) as post, mock.patch.object( - transports.PredictionServiceRestInterceptor, "pre_generate_content" + transports.PredictionServiceRestInterceptor, "pre_count_tokens" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = prediction_service.GenerateContentRequest.pb( - prediction_service.GenerateContentRequest() + pb_message = prediction_service.CountTokensRequest.pb( + prediction_service.CountTokensRequest() ) transcode.return_value = { "method": "post", @@ -6025,19 +6863,19 @@ def test_generate_content_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = prediction_service.GenerateContentResponse.to_json( - prediction_service.GenerateContentResponse() + req.return_value._content = prediction_service.CountTokensResponse.to_json( + prediction_service.CountTokensResponse() ) - request = prediction_service.GenerateContentRequest() + request = prediction_service.CountTokensRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = prediction_service.GenerateContentResponse() + post.return_value = prediction_service.CountTokensResponse() - client.generate_content( + client.count_tokens( request, metadata=[ ("key", "val"), @@ -6049,8 +6887,8 @@ def test_generate_content_rest_interceptors(null_interceptor): post.assert_called_once() -def test_generate_content_rest_bad_request( - transport: str = "rest", request_type=prediction_service.GenerateContentRequest +def test_count_tokens_rest_bad_request( + transport: str = "rest", request_type=prediction_service.CountTokensRequest ): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -6058,7 +6896,7 @@ def test_generate_content_rest_bad_request( ) # send 
a request that will satisfy transcoding - request_init = {"model": "projects/sample1/locations/sample2/endpoints/sample3"} + request_init = {"endpoint": "projects/sample1/locations/sample2/endpoints/sample3"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -6070,10 +6908,10 @@ def test_generate_content_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.generate_content(request) + client.count_tokens(request) -def test_generate_content_rest_flattened(): +def test_count_tokens_rest_flattened(): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -6082,17 +6920,17 @@ def test_generate_content_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
- return_value = prediction_service.GenerateContentResponse() + return_value = prediction_service.CountTokensResponse() # get arguments that satisfy an http rule for this method sample_request = { - "model": "projects/sample1/locations/sample2/endpoints/sample3" + "endpoint": "projects/sample1/locations/sample2/endpoints/sample3" } # get truthy value for each flattened field mock_args = dict( - model="model_value", - contents=[content.Content(role="role_value")], + endpoint="endpoint_value", + instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)], ) mock_args.update(sample_request) @@ -6100,25 +6938,25 @@ def test_generate_content_rest_flattened(): response_value = Response() response_value.status_code = 200 # Convert return value to protobuf type - return_value = prediction_service.GenerateContentResponse.pb(return_value) + return_value = prediction_service.CountTokensResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - client.generate_content(**mock_args) + client.count_tokens(**mock_args) # Establish that the underlying call was made with the expected # request object values. assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{model=projects/*/locations/*/endpoints/*}:generateContent" + "%s/v1beta1/{endpoint=projects/*/locations/*/endpoints/*}:countTokens" % client.transport._host, args[1], ) -def test_generate_content_rest_flattened_error(transport: str = "rest"): +def test_count_tokens_rest_flattened_error(transport: str = "rest"): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -6127,14 +6965,14 @@ def test_generate_content_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - client.generate_content( - prediction_service.GenerateContentRequest(), - model="model_value", - contents=[content.Content(role="role_value")], + client.count_tokens( + prediction_service.CountTokensRequest(), + endpoint="endpoint_value", + instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)], ) -def test_generate_content_rest_error(): +def test_count_tokens_rest_error(): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -6147,7 +6985,7 @@ def test_generate_content_rest_error(): dict, ], ) -def test_stream_generate_content_rest(request_type): +def test_generate_content_rest(request_type): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -6169,22 +7007,53 @@ def test_stream_generate_content_rest(request_type): return_value = prediction_service.GenerateContentResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) - json_return_value = "[{}]".format(json_return_value) - response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - with mock.patch.object(response_value, "iter_content") as iter_content: - iter_content.return_value = iter(json_return_value) - response = client.stream_generate_content(request) - - assert isinstance(response, Iterable) - response = next(response) + response = client.generate_content(request) # Establish that the response is the type that we expect. 
assert isinstance(response, prediction_service.GenerateContentResponse) -def test_stream_generate_content_rest_required_fields( +def test_generate_content_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.generate_content in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.generate_content + ] = mock_rpc + + request = {} + client.generate_content(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.generate_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_generate_content_rest_required_fields( request_type=prediction_service.GenerateContentRequest, ): transport_class = transports.PredictionServiceRestTransport @@ -6201,7 +7070,7 @@ def test_stream_generate_content_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).stream_generate_content._get_unset_required_fields(jsonified_request) + ).generate_content._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present @@ -6210,7 +7079,7 @@ def test_stream_generate_content_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).stream_generate_content._get_unset_required_fields(jsonified_request) + ).generate_content._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone @@ -6248,26 +7117,23 @@ def test_stream_generate_content_rest_required_fields( # Convert return value to protobuf type return_value = prediction_service.GenerateContentResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) - json_return_value = "[{}]".format(json_return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - with mock.patch.object(response_value, "iter_content") as iter_content: - iter_content.return_value = iter(json_return_value) - response = client.stream_generate_content(request) + response = client.generate_content(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_stream_generate_content_rest_unset_required_fields(): +def 
test_generate_content_rest_unset_required_fields(): transport = transports.PredictionServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.stream_generate_content._get_unset_required_fields({}) + unset_fields = transport.generate_content._get_unset_required_fields({}) assert set(unset_fields) == ( set(()) & set( @@ -6280,7 +7146,7 @@ def test_stream_generate_content_rest_unset_required_fields(): @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_stream_generate_content_rest_interceptors(null_interceptor): +def test_generate_content_rest_interceptors(null_interceptor): transport = transports.PredictionServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -6293,9 +7159,9 @@ def test_stream_generate_content_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - transports.PredictionServiceRestInterceptor, "post_stream_generate_content" + transports.PredictionServiceRestInterceptor, "post_generate_content" ) as post, mock.patch.object( - transports.PredictionServiceRestInterceptor, "pre_stream_generate_content" + transports.PredictionServiceRestInterceptor, "pre_generate_content" ) as pre: pre.assert_not_called() post.assert_not_called() @@ -6315,7 +7181,6 @@ def test_stream_generate_content_rest_interceptors(null_interceptor): req.return_value._content = prediction_service.GenerateContentResponse.to_json( prediction_service.GenerateContentResponse() ) - req.return_value._content = "[{}]".format(req.return_value._content) request = prediction_service.GenerateContentRequest() metadata = [ @@ -6325,7 +7190,7 @@ def test_stream_generate_content_rest_interceptors(null_interceptor): pre.return_value = request, metadata post.return_value = prediction_service.GenerateContentResponse() - client.stream_generate_content( + client.generate_content( request, metadata=[ ("key", "val"), @@ -6337,7 +7202,7 @@ 
def test_stream_generate_content_rest_interceptors(null_interceptor): post.assert_called_once() -def test_stream_generate_content_rest_bad_request( +def test_generate_content_rest_bad_request( transport: str = "rest", request_type=prediction_service.GenerateContentRequest ): client = PredictionServiceClient( @@ -6358,10 +7223,10 @@ def test_stream_generate_content_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.stream_generate_content(request) + client.generate_content(request) -def test_stream_generate_content_rest_flattened(): +def test_generate_content_rest_flattened(): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -6390,26 +7255,23 @@ def test_stream_generate_content_rest_flattened(): # Convert return value to protobuf type return_value = prediction_service.GenerateContentResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) - json_return_value = "[{}]".format(json_return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value - with mock.patch.object(response_value, "iter_content") as iter_content: - iter_content.return_value = iter(json_return_value) - client.stream_generate_content(**mock_args) + client.generate_content(**mock_args) # Establish that the underlying call was made with the expected # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{model=projects/*/locations/*/endpoints/*}:streamGenerateContent" + "%s/v1beta1/{model=projects/*/locations/*/endpoints/*}:generateContent" % client.transport._host, args[1], ) -def test_stream_generate_content_rest_flattened_error(transport: str = "rest"): +def test_generate_content_rest_flattened_error(transport: str = "rest"): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -6418,14 +7280,14 @@ def test_stream_generate_content_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.stream_generate_content( + client.generate_content( prediction_service.GenerateContentRequest(), model="model_value", contents=[content.Content(role="role_value")], ) -def test_stream_generate_content_rest_error(): +def test_generate_content_rest_error(): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -6434,108 +7296,30 @@ def test_stream_generate_content_rest_error(): @pytest.mark.parametrize( "request_type", [ - prediction_service.ChatCompletionsRequest, + prediction_service.GenerateContentRequest, dict, ], ) -def test_chat_completions_rest(request_type): +def test_stream_generate_content_rest(request_type): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", ) # send a request that will satisfy transcoding - request_init = {"endpoint": "projects/sample1/locations/sample2/endpoints/sample3"} - request_init["http_body"] = { - "content_type": "content_type_value", - "data": b"data_blob", - "extensions": [ - { - "type_url": "type.googleapis.com/google.protobuf.Duration", - "value": b"\x08\x0c\x10\xdb\x07", - } - ], - } - # The version of a generated dependency at test runtime may differ 
from the version used during generation. - # Delete any fields which are not present in the current runtime dependency - # See https://github.com/googleapis/gapic-generator-python/issues/1748 - - # Determine if the message type is proto-plus or protobuf - test_field = prediction_service.ChatCompletionsRequest.meta.fields["http_body"] - - def get_message_fields(field): - # Given a field which is a message (composite type), return a list with - # all the fields of the message. - # If the field is not a composite type, return an empty list. - message_fields = [] - - if hasattr(field, "message") and field.message: - is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") - - if is_field_type_proto_plus_type: - message_fields = field.message.meta.fields.values() - # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types - else: # pragma: NO COVER - message_fields = field.message.DESCRIPTOR.fields - return message_fields - - runtime_nested_fields = [ - (field.name, nested_field.name) - for field in get_message_fields(test_field) - for nested_field in get_message_fields(field) - ] - - subfields_not_in_runtime = [] - - # For each item in the sample request, create a list of sub fields which are not present at runtime - # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime - for field, value in request_init["http_body"].items(): # pragma: NO COVER - result = None - is_repeated = False - # For repeated fields - if isinstance(value, list) and len(value): - is_repeated = True - result = value[0] - # For fields where the type is another message - if isinstance(value, dict): - result = value - - if result and hasattr(result, "keys"): - for subfield in result.keys(): - if (field, subfield) not in runtime_nested_fields: - subfields_not_in_runtime.append( - { - "field": field, - "subfield": subfield, - "is_repeated": is_repeated, - } - ) - - # Remove fields from the sample request which are not 
present in the runtime version of the dependency - # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime - for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER - field = subfield_to_delete.get("field") - field_repeated = subfield_to_delete.get("is_repeated") - subfield = subfield_to_delete.get("subfield") - if subfield: - if field_repeated: - for i in range(0, len(request_init["http_body"][field])): - del request_init["http_body"][field][i][subfield] - else: - del request_init["http_body"][field][subfield] + request_init = {"model": "projects/sample1/locations/sample2/endpoints/sample3"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = httpbody_pb2.HttpBody( - content_type="content_type_value", - data=b"data_blob", - ) + return_value = prediction_service.GenerateContentResponse() # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 + # Convert return value to protobuf type + return_value = prediction_service.GenerateContentResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) json_return_value = "[{}]".format(json_return_value) @@ -6544,24 +7328,63 @@ def get_message_fields(field): req.return_value = response_value with mock.patch.object(response_value, "iter_content") as iter_content: iter_content.return_value = iter(json_return_value) - response = client.chat_completions(request) + response = client.stream_generate_content(request) assert isinstance(response, Iterable) response = next(response) # Establish that the response is the type that we expect. 
- assert isinstance(response, httpbody_pb2.HttpBody) - assert response.content_type == "content_type_value" - assert response.data == b"data_blob" + assert isinstance(response, prediction_service.GenerateContentResponse) + + +def test_stream_generate_content_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.stream_generate_content + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.stream_generate_content + ] = mock_rpc + + request = {} + client.stream_generate_content(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + client.stream_generate_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 -def test_chat_completions_rest_required_fields( - request_type=prediction_service.ChatCompletionsRequest, + +def test_stream_generate_content_rest_required_fields( + request_type=prediction_service.GenerateContentRequest, ): transport_class = transports.PredictionServiceRestTransport request_init = {} - request_init["endpoint"] = "" + request_init["model"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -6572,21 +7395,21 @@ def test_chat_completions_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).chat_completions._get_unset_required_fields(jsonified_request) + ).stream_generate_content._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["endpoint"] = "endpoint_value" + jsonified_request["model"] = "model_value" unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).chat_completions._get_unset_required_fields(jsonified_request) + ).stream_generate_content._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "endpoint" in jsonified_request - assert jsonified_request["endpoint"] == "endpoint_value" + assert "model" in jsonified_request + assert jsonified_request["model"] == "model_value" client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -6595,7 +7418,7 @@ def test_chat_completions_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. 
- return_value = httpbody_pb2.HttpBody() + return_value = prediction_service.GenerateContentResponse() # Mock the http request call within the method and fake a response. with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -6616,6 +7439,8 @@ def test_chat_completions_rest_required_fields( response_value = Response() response_value.status_code = 200 + # Convert return value to protobuf type + return_value = prediction_service.GenerateContentResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) json_return_value = "[{}]".format(json_return_value) @@ -6624,24 +7449,32 @@ def test_chat_completions_rest_required_fields( with mock.patch.object(response_value, "iter_content") as iter_content: iter_content.return_value = iter(json_return_value) - response = client.chat_completions(request) + response = client.stream_generate_content(request) expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_chat_completions_rest_unset_required_fields(): +def test_stream_generate_content_rest_unset_required_fields(): transport = transports.PredictionServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.chat_completions._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("endpoint",))) + unset_fields = transport.stream_generate_content._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "model", + "contents", + ) + ) + ) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_chat_completions_rest_interceptors(null_interceptor): +def test_stream_generate_content_rest_interceptors(null_interceptor): transport = transports.PredictionServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -6654,14 +7487,14 @@ def test_chat_completions_rest_interceptors(null_interceptor): ) as 
req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - transports.PredictionServiceRestInterceptor, "post_chat_completions" + transports.PredictionServiceRestInterceptor, "post_stream_generate_content" ) as post, mock.patch.object( - transports.PredictionServiceRestInterceptor, "pre_chat_completions" + transports.PredictionServiceRestInterceptor, "pre_stream_generate_content" ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = prediction_service.ChatCompletionsRequest.pb( - prediction_service.ChatCompletionsRequest() + pb_message = prediction_service.GenerateContentRequest.pb( + prediction_service.GenerateContentRequest() ) transcode.return_value = { "method": "post", @@ -6673,18 +7506,20 @@ def test_chat_completions_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = json_format.MessageToJson(httpbody_pb2.HttpBody()) + req.return_value._content = prediction_service.GenerateContentResponse.to_json( + prediction_service.GenerateContentResponse() + ) req.return_value._content = "[{}]".format(req.return_value._content) - request = prediction_service.ChatCompletionsRequest() + request = prediction_service.GenerateContentRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = httpbody_pb2.HttpBody() + post.return_value = prediction_service.GenerateContentResponse() - client.chat_completions( + client.stream_generate_content( request, metadata=[ ("key", "val"), @@ -6696,8 +7531,8 @@ def test_chat_completions_rest_interceptors(null_interceptor): post.assert_called_once() -def test_chat_completions_rest_bad_request( - transport: str = "rest", request_type=prediction_service.ChatCompletionsRequest +def test_stream_generate_content_rest_bad_request( + transport: str = "rest", request_type=prediction_service.GenerateContentRequest 
): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -6705,7 +7540,7 @@ def test_chat_completions_rest_bad_request( ) # send a request that will satisfy transcoding - request_init = {"endpoint": "projects/sample1/locations/sample2/endpoints/sample3"} + request_init = {"model": "projects/sample1/locations/sample2/endpoints/sample3"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -6717,10 +7552,10 @@ def test_chat_completions_rest_bad_request( response_value.status_code = 400 response_value.request = Request() req.return_value = response_value - client.chat_completions(request) + client.stream_generate_content(request) -def test_chat_completions_rest_flattened(): +def test_stream_generate_content_rest_flattened(): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -6729,23 +7564,25 @@ def test_chat_completions_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
- return_value = httpbody_pb2.HttpBody() + return_value = prediction_service.GenerateContentResponse() # get arguments that satisfy an http rule for this method sample_request = { - "endpoint": "projects/sample1/locations/sample2/endpoints/sample3" + "model": "projects/sample1/locations/sample2/endpoints/sample3" } # get truthy value for each flattened field mock_args = dict( - endpoint="endpoint_value", - http_body=httpbody_pb2.HttpBody(content_type="content_type_value"), + model="model_value", + contents=[content.Content(role="role_value")], ) mock_args.update(sample_request) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 + # Convert return value to protobuf type + return_value = prediction_service.GenerateContentResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) json_return_value = "[{}]".format(json_return_value) response_value._content = json_return_value.encode("UTF-8") @@ -6753,20 +7590,20 @@ def test_chat_completions_rest_flattened(): with mock.patch.object(response_value, "iter_content") as iter_content: iter_content.return_value = iter(json_return_value) - client.chat_completions(**mock_args) + client.stream_generate_content(**mock_args) # Establish that the underlying call was made with the expected # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{endpoint=projects/*/locations/*/endpoints/*}/chat/completions" + "%s/v1beta1/{model=projects/*/locations/*/endpoints/*}:streamGenerateContent" % client.transport._host, args[1], ) -def test_chat_completions_rest_flattened_error(transport: str = "rest"): +def test_stream_generate_content_rest_flattened_error(transport: str = "rest"): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -6775,14 +7612,14 @@ def test_chat_completions_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.chat_completions( - prediction_service.ChatCompletionsRequest(), - endpoint="endpoint_value", - http_body=httpbody_pb2.HttpBody(content_type="content_type_value"), + client.stream_generate_content( + prediction_service.GenerateContentRequest(), + model="model_value", + contents=[content.Content(role="role_value")], ) -def test_chat_completions_rest_error(): +def test_stream_generate_content_rest_error(): client = PredictionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -6992,7 +7829,6 @@ def test_prediction_service_base_transport(): "count_tokens", "generate_content", "stream_generate_content", - "chat_completions", "set_iam_policy", "get_iam_policy", "test_iam_permissions", @@ -7300,9 +8136,6 @@ def test_prediction_service_client_transport_session_collision(transport_name): session1 = client1.transport.stream_generate_content._session session2 = client2.transport.stream_generate_content._session assert session1 != session2 - session1 = client1.transport.chat_completions._session - session2 = client2.transport.chat_completions._session - assert session1 != session2 def test_prediction_service_grpc_transport_channel(): diff --git 
a/tests/unit/gapic/aiplatform_v1beta1/test_reasoning_engine_execution_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_reasoning_engine_execution_service.py index e478d0c933..71b781b7c7 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_reasoning_engine_execution_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_reasoning_engine_execution_service.py @@ -1285,6 +1285,9 @@ def test_query_reasoning_engine_empty_call(): with mock.patch.object( type(client.transport.query_reasoning_engine), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.query_reasoning_engine() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1312,6 +1315,9 @@ def test_query_reasoning_engine_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.query_reasoning_engine), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.query_reasoning_engine(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1322,6 +1328,46 @@ def test_query_reasoning_engine_non_empty_request_with_auto_populated_field(): ) +def test_query_reasoning_engine_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ReasoningEngineExecutionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.query_reasoning_engine + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.query_reasoning_engine + ] = mock_rpc + request = {} + client.query_reasoning_engine(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.query_reasoning_engine(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_query_reasoning_engine_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1347,6 +1393,52 @@ async def test_query_reasoning_engine_empty_call_async(): ) +@pytest.mark.asyncio +async def test_query_reasoning_engine_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ReasoningEngineExecutionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.query_reasoning_engine + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.query_reasoning_engine + ] = mock_object + + request = {} + await client.query_reasoning_engine(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.query_reasoning_engine(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_query_reasoning_engine_async( transport: str = "grpc_asyncio", @@ -1500,6 +1592,47 @@ def test_query_reasoning_engine_rest(request_type): ) +def test_query_reasoning_engine_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ReasoningEngineExecutionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.query_reasoning_engine + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.query_reasoning_engine + ] = mock_rpc + + request = {} + client.query_reasoning_engine(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.query_reasoning_engine(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_query_reasoning_engine_rest_required_fields( request_type=reasoning_engine_execution_service.QueryReasoningEngineRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_reasoning_engine_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_reasoning_engine_service.py index 0778e5cfe2..fe9c3b6f7a 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_reasoning_engine_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_reasoning_engine_service.py @@ -1259,6 +1259,9 @@ def test_create_reasoning_engine_empty_call(): with mock.patch.object( type(client.transport.create_reasoning_engine), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_reasoning_engine() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1284,6 +1287,9 @@ def test_create_reasoning_engine_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_reasoning_engine), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_reasoning_engine(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1292,6 +1298,50 @@ def test_create_reasoning_engine_non_empty_request_with_auto_populated_field(): ) +def test_create_reasoning_engine_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ReasoningEngineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_reasoning_engine + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_reasoning_engine + ] = mock_rpc + request = {} + client.create_reasoning_engine(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_reasoning_engine(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_reasoning_engine_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1315,6 +1365,56 @@ async def test_create_reasoning_engine_empty_call_async(): assert args[0] == reasoning_engine_service.CreateReasoningEngineRequest() +@pytest.mark.asyncio +async def test_create_reasoning_engine_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ReasoningEngineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_reasoning_engine + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_reasoning_engine + ] = mock_object + + request = {} + await client.create_reasoning_engine(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_reasoning_engine(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_reasoning_engine_async( transport: str = "grpc_asyncio", @@ -1571,6 +1671,9 @@ def test_get_reasoning_engine_empty_call(): with mock.patch.object( type(client.transport.get_reasoning_engine), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_reasoning_engine() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1596,6 +1699,9 @@ def test_get_reasoning_engine_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_reasoning_engine), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_reasoning_engine(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1604,6 +1710,45 @@ def test_get_reasoning_engine_non_empty_request_with_auto_populated_field(): ) +def test_get_reasoning_engine_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ReasoningEngineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_reasoning_engine in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_reasoning_engine + ] = mock_rpc + request = {} + client.get_reasoning_engine(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_reasoning_engine(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_reasoning_engine_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1632,6 +1777,52 @@ async def test_get_reasoning_engine_empty_call_async(): assert args[0] == reasoning_engine_service.GetReasoningEngineRequest() +@pytest.mark.asyncio +async def test_get_reasoning_engine_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ReasoningEngineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_reasoning_engine + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_reasoning_engine + ] = mock_object + + request = {} + await client.get_reasoning_engine(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_reasoning_engine(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_reasoning_engine_async( transport: str = "grpc_asyncio", @@ -1881,6 +2072,9 @@ def test_list_reasoning_engines_empty_call(): with mock.patch.object( type(client.transport.list_reasoning_engines), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_reasoning_engines() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1908,6 +2102,9 @@ def test_list_reasoning_engines_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_reasoning_engines), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_reasoning_engines(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1918,6 +2115,46 @@ def test_list_reasoning_engines_non_empty_request_with_auto_populated_field(): ) +def test_list_reasoning_engines_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ReasoningEngineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_reasoning_engines + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_reasoning_engines + ] = mock_rpc + request = {} + client.list_reasoning_engines(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_reasoning_engines(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_reasoning_engines_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1943,6 +2180,52 @@ async def test_list_reasoning_engines_empty_call_async(): assert args[0] == reasoning_engine_service.ListReasoningEnginesRequest() +@pytest.mark.asyncio +async def test_list_reasoning_engines_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ReasoningEngineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_reasoning_engines + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_reasoning_engines + ] = mock_object + + request = {} + await client.list_reasoning_engines(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_reasoning_engines(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_reasoning_engines_async( transport: str = "grpc_asyncio", @@ -2381,6 +2664,9 @@ def test_delete_reasoning_engine_empty_call(): with mock.patch.object( type(client.transport.delete_reasoning_engine), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_reasoning_engine() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2406,6 +2692,9 @@ def test_delete_reasoning_engine_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_reasoning_engine), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_reasoning_engine(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2414,6 +2703,50 @@ def test_delete_reasoning_engine_non_empty_request_with_auto_populated_field(): ) +def test_delete_reasoning_engine_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ReasoningEngineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_reasoning_engine + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a 
string. + ) + client._transport._wrapped_methods[ + client._transport.delete_reasoning_engine + ] = mock_rpc + request = {} + client.delete_reasoning_engine(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_reasoning_engine(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_reasoning_engine_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2437,6 +2770,56 @@ async def test_delete_reasoning_engine_empty_call_async(): assert args[0] == reasoning_engine_service.DeleteReasoningEngineRequest() +@pytest.mark.asyncio +async def test_delete_reasoning_engine_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ReasoningEngineServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_reasoning_engine + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_reasoning_engine + ] = mock_object + + request = {} + await client.delete_reasoning_engine(request) + + # 
Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_reasoning_engine(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_reasoning_engine_async( transport: str = "grpc_asyncio", @@ -2748,6 +3131,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_reasoning_engine_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ReasoningEngineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_reasoning_engine + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_reasoning_engine + ] = mock_rpc + + request = {} + client.create_reasoning_engine(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_reasoning_engine(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_reasoning_engine_rest_required_fields( request_type=reasoning_engine_service.CreateReasoningEngineRequest, ): @@ -3032,6 +3460,46 @@ def test_get_reasoning_engine_rest(request_type): assert response.etag == "etag_value" +def test_get_reasoning_engine_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ReasoningEngineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_reasoning_engine in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_reasoning_engine + ] = mock_rpc + + request = {} + client.get_reasoning_engine(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_reasoning_engine(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_reasoning_engine_rest_required_fields( request_type=reasoning_engine_service.GetReasoningEngineRequest, ): @@ -3306,6 +3774,47 @@ def test_list_reasoning_engines_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_reasoning_engines_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ReasoningEngineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_reasoning_engines + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_reasoning_engines + ] = mock_rpc + + request = {} + client.list_reasoning_engines(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_reasoning_engines(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_reasoning_engines_rest_required_fields( request_type=reasoning_engine_service.ListReasoningEnginesRequest, ): @@ -3652,6 +4161,51 @@ def test_delete_reasoning_engine_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_reasoning_engine_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ReasoningEngineServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_reasoning_engine + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_reasoning_engine + ] = mock_rpc + + request = {} + client.delete_reasoning_engine(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_reasoning_engine(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_reasoning_engine_rest_required_fields( request_type=reasoning_engine_service.DeleteReasoningEngineRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_schedule_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_schedule_service.py index 7f0248bcd3..1c38a0bc1d 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_schedule_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_schedule_service.py @@ -71,6 +71,8 @@ from google.cloud.aiplatform_v1beta1.types import model_monitoring_job from google.cloud.aiplatform_v1beta1.types import model_monitoring_service from google.cloud.aiplatform_v1beta1.types import model_monitoring_spec +from google.cloud.aiplatform_v1beta1.types import notebook_execution_job +from google.cloud.aiplatform_v1beta1.types import notebook_service from google.cloud.aiplatform_v1beta1.types import operation as gca_operation from google.cloud.aiplatform_v1beta1.types import pipeline_failure_policy from google.cloud.aiplatform_v1beta1.types import pipeline_job @@ -87,6 +89,7 @@ from google.longrunning import operations_pb2 # type: ignore from google.oauth2 import service_account from google.protobuf import any_pb2 # type: ignore +from google.protobuf import duration_pb2 # type: ignore from google.protobuf import empty_pb2 # type: ignore from google.protobuf import field_mask_pb2 # type: ignore from google.protobuf import struct_pb2 # type: ignore @@ -1246,6 +1249,9 @@ def test_create_schedule_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object(type(client.transport.create_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_schedule() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1269,6 +1275,9 @@ def test_create_schedule_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_schedule(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1277,6 +1286,41 @@ def test_create_schedule_non_empty_request_with_auto_populated_field(): ) +def test_create_schedule_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_schedule] = mock_rpc + request = {} + client.create_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_schedule_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1307,6 +1351,52 @@ async def test_create_schedule_empty_call_async(): assert args[0] == schedule_service.CreateScheduleRequest() +@pytest.mark.asyncio +async def test_create_schedule_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_schedule + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_schedule + ] = mock_object + + request = {} + await client.create_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_schedule_async( transport: str = "grpc_asyncio", request_type=schedule_service.CreateScheduleRequest @@ -1556,6 +1646,9 @@ def test_delete_schedule_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_schedule() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1579,6 +1672,9 @@ def test_delete_schedule_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_schedule(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1587,6 +1683,45 @@ def test_delete_schedule_non_empty_request_with_auto_populated_field(): ) +def test_delete_schedule_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_schedule] = mock_rpc + request = {} + client.delete_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_schedule_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1608,6 +1743,56 @@ async def test_delete_schedule_empty_call_async(): assert args[0] == schedule_service.DeleteScheduleRequest() +@pytest.mark.asyncio +async def test_delete_schedule_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_schedule + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_schedule + ] = mock_object + + request = {} + await client.delete_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_schedule_async( transport: str = "grpc_asyncio", request_type=schedule_service.DeleteScheduleRequest @@ -1848,6 +2033,9 @@ def test_get_schedule_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_schedule() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1871,6 +2059,9 @@ def test_get_schedule_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_schedule(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1879,6 +2070,41 @@ def test_get_schedule_non_empty_request_with_auto_populated_field(): ) +def test_get_schedule_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_schedule] = mock_rpc + request = {} + client.get_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_schedule_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1909,6 +2135,52 @@ async def test_get_schedule_empty_call_async(): assert args[0] == schedule_service.GetScheduleRequest() +@pytest.mark.asyncio +async def test_get_schedule_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_schedule + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_schedule + ] = mock_object + + request = {} + await client.get_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_schedule_async( transport: str = "grpc_asyncio", request_type=schedule_service.GetScheduleRequest @@ -2147,6 +2419,9 @@ def test_list_schedules_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_schedules), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_schedules() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2173,6 +2448,9 @@ def test_list_schedules_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_schedules), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_schedules(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2184,6 +2462,41 @@ def test_list_schedules_non_empty_request_with_auto_populated_field(): ) +def test_list_schedules_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_schedules in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_schedules] = mock_rpc + request = {} + client.list_schedules(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_schedules(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_schedules_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2207,6 +2520,52 @@ async def test_list_schedules_empty_call_async(): assert args[0] == schedule_service.ListSchedulesRequest() +@pytest.mark.asyncio +async def test_list_schedules_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_schedules + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_schedules + ] = mock_object + + request = {} + await client.list_schedules(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_schedules(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_schedules_async( transport: str = "grpc_asyncio", request_type=schedule_service.ListSchedulesRequest @@ -2622,6 +2981,9 @@ def test_pause_schedule_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.pause_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.pause_schedule() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2645,6 +3007,9 @@ def test_pause_schedule_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.pause_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.pause_schedule(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2653,6 +3018,41 @@ def test_pause_schedule_non_empty_request_with_auto_populated_field(): ) +def test_pause_schedule_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.pause_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.pause_schedule] = mock_rpc + request = {} + client.pause_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.pause_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_pause_schedule_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2672,6 +3072,52 @@ async def test_pause_schedule_empty_call_async(): assert args[0] == schedule_service.PauseScheduleRequest() +@pytest.mark.asyncio +async def test_pause_schedule_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.pause_schedule + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.pause_schedule + ] = mock_object + + request = {} + await client.pause_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.pause_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_pause_schedule_async( transport: str = "grpc_asyncio", request_type=schedule_service.PauseScheduleRequest @@ -2888,6 +3334,9 @@ def test_resume_schedule_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.resume_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.resume_schedule() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2911,6 +3360,9 @@ def test_resume_schedule_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.resume_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.resume_schedule(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2919,6 +3371,41 @@ def test_resume_schedule_non_empty_request_with_auto_populated_field(): ) +def test_resume_schedule_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.resume_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.resume_schedule] = mock_rpc + request = {} + client.resume_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.resume_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_resume_schedule_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2938,6 +3425,52 @@ async def test_resume_schedule_empty_call_async(): assert args[0] == schedule_service.ResumeScheduleRequest() +@pytest.mark.asyncio +async def test_resume_schedule_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.resume_schedule + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.resume_schedule + ] = mock_object + + request = {} + await client.resume_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.resume_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_resume_schedule_async( transport: str = "grpc_asyncio", request_type=schedule_service.ResumeScheduleRequest @@ -3182,6 +3715,9 @@ def test_update_schedule_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_schedule() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3203,12 +3739,50 @@ def test_update_schedule_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.update_schedule), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_schedule(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == schedule_service.UpdateScheduleRequest() +def test_update_schedule_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_schedule] = mock_rpc + request = {} + client.update_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_schedule_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3239,6 +3813,52 @@ async def test_update_schedule_empty_call_async(): assert args[0] == schedule_service.UpdateScheduleRequest() +@pytest.mark.asyncio +async def test_update_schedule_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_schedule + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_schedule + ] = mock_object + + request = {} + await client.update_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_schedule_async( transport: str = "grpc_asyncio", request_type=schedule_service.UpdateScheduleRequest @@ -3718,6 +4338,32 @@ def test_create_schedule_rest(request_type): }, "model_monitoring_job_id": "model_monitoring_job_id_value", }, + "create_notebook_execution_job_request": { + "parent": "parent_value", + "notebook_execution_job": { + "dataform_repository_source": { + "dataform_repository_resource_name": "dataform_repository_resource_name_value", + "commit_sha": "commit_sha_value", + }, + "gcs_notebook_source": { + "uri": "uri_value", + "generation": "generation_value", + }, + "notebook_runtime_template_resource_name": "notebook_runtime_template_resource_name_value", + "gcs_output_uri": "gcs_output_uri_value", + "execution_user": "execution_user_value", + "service_account": "service_account_value", + "name": "name_value", + "display_name": "display_name_value", + "execution_timeout": {"seconds": 751, "nanos": 543}, + "schedule_resource_name": "schedule_resource_name_value", + "job_state": 1, + "status": {}, + "create_time": {}, + "update_time": {}, + }, + "notebook_execution_job_id": "notebook_execution_job_id_value", + }, "name": "name_value", "display_name": "display_name_value", "start_time": {}, @@ -3845,6 +4491,42 @@ def get_message_fields(field): assert response.catch_up is True +def test_create_schedule_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert 
wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_schedule] = mock_rpc + + request = {} + client.create_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.create_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_schedule_rest_required_fields( request_type=schedule_service.CreateScheduleRequest, ): @@ -4118,6 +4800,46 @@ def test_delete_schedule_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_schedule_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_schedule] = mock_rpc + + request = {} + client.delete_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_schedule_rest_required_fields( request_type=schedule_service.DeleteScheduleRequest, ): @@ -4399,6 +5121,42 @@ def test_get_schedule_rest(request_type): assert response.catch_up is True +def test_get_schedule_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_schedule] = mock_rpc + + request = {} + client.get_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_schedule_rest_required_fields( request_type=schedule_service.GetScheduleRequest, ): @@ -4666,6 +5424,42 @@ def test_list_schedules_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_schedules_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_schedules in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_schedules] = mock_rpc + + request = {} + client.list_schedules(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_schedules(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_schedules_rest_required_fields( request_type=schedule_service.ListSchedulesRequest, ): @@ -5004,6 +5798,42 @@ def test_pause_schedule_rest(request_type): assert response is None +def test_pause_schedule_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.pause_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.pause_schedule] = mock_rpc + + request = {} + client.pause_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.pause_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_pause_schedule_rest_required_fields( request_type=schedule_service.PauseScheduleRequest, ): @@ -5256,6 +6086,42 @@ def test_resume_schedule_rest(request_type): assert response is None +def test_resume_schedule_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.resume_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.resume_schedule] = mock_rpc + + request = {} + client.resume_schedule(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.resume_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_resume_schedule_rest_required_fields( request_type=schedule_service.ResumeScheduleRequest, ): @@ -5750,6 +6616,32 @@ def test_update_schedule_rest(request_type): }, "model_monitoring_job_id": "model_monitoring_job_id_value", }, + "create_notebook_execution_job_request": { + "parent": "parent_value", + "notebook_execution_job": { + "dataform_repository_source": { + "dataform_repository_resource_name": "dataform_repository_resource_name_value", + "commit_sha": "commit_sha_value", + }, + "gcs_notebook_source": { + "uri": "uri_value", + "generation": "generation_value", + }, + "notebook_runtime_template_resource_name": "notebook_runtime_template_resource_name_value", + "gcs_output_uri": "gcs_output_uri_value", + "execution_user": "execution_user_value", + "service_account": "service_account_value", + "name": "name_value", + "display_name": "display_name_value", + "execution_timeout": {"seconds": 751, "nanos": 543}, + "schedule_resource_name": "schedule_resource_name_value", + "job_state": 1, + "status": {}, + "create_time": {}, + "update_time": {}, + }, + "notebook_execution_job_id": "notebook_execution_job_id_value", + }, "name": "projects/sample1/locations/sample2/schedules/sample3", "display_name": "display_name_value", "start_time": {}, @@ -5877,6 +6769,42 @@ def get_message_fields(field): assert response.catch_up is True +def test_update_schedule_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert 
wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_schedule in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_schedule] = mock_rpc + + request = {} + client.update_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_schedule(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_schedule_rest_required_fields( request_type=schedule_service.UpdateScheduleRequest, ): @@ -7016,6 +7944,62 @@ def test_parse_network_path(): assert expected == actual +def test_notebook_execution_job_path(): + project = "cuttlefish" + location = "mussel" + notebook_execution_job = "winkle" + expected = "projects/{project}/locations/{location}/notebookExecutionJobs/{notebook_execution_job}".format( + project=project, + location=location, + notebook_execution_job=notebook_execution_job, + ) + actual = ScheduleServiceClient.notebook_execution_job_path( + project, location, notebook_execution_job + ) + assert expected == actual + + +def test_parse_notebook_execution_job_path(): + expected = { + "project": "nautilus", + "location": "scallop", + "notebook_execution_job": "abalone", + } + path = ScheduleServiceClient.notebook_execution_job_path(**expected) + + # Check that the path construction is reversible. 
+ actual = ScheduleServiceClient.parse_notebook_execution_job_path(path) + assert expected == actual + + +def test_notebook_runtime_template_path(): + project = "squid" + location = "clam" + notebook_runtime_template = "whelk" + expected = "projects/{project}/locations/{location}/notebookRuntimeTemplates/{notebook_runtime_template}".format( + project=project, + location=location, + notebook_runtime_template=notebook_runtime_template, + ) + actual = ScheduleServiceClient.notebook_runtime_template_path( + project, location, notebook_runtime_template + ) + assert expected == actual + + +def test_parse_notebook_runtime_template_path(): + expected = { + "project": "octopus", + "location": "oyster", + "notebook_runtime_template": "nudibranch", + } + path = ScheduleServiceClient.notebook_runtime_template_path(**expected) + + # Check that the path construction is reversible. + actual = ScheduleServiceClient.parse_notebook_runtime_template_path(path) + assert expected == actual + + def test_pipeline_job_path(): project = "cuttlefish" location = "mussel" diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py index 3ad057116e..689cf9e3ca 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py @@ -1254,6 +1254,9 @@ def test_create_specialist_pool_empty_call(): with mock.patch.object( type(client.transport.create_specialist_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1279,6 +1282,9 @@ def test_create_specialist_pool_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_specialist_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_specialist_pool(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1287,6 +1293,50 @@ def test_create_specialist_pool_non_empty_request_with_auto_populated_field(): ) +def test_create_specialist_pool_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_specialist_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_specialist_pool + ] = mock_rpc + request = {} + client.create_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_specialist_pool_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1310,6 +1360,56 @@ async def test_create_specialist_pool_empty_call_async(): assert args[0] == specialist_pool_service.CreateSpecialistPoolRequest() +@pytest.mark.asyncio +async def test_create_specialist_pool_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_specialist_pool + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_specialist_pool + ] = mock_object + + request = {} + await client.create_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_specialist_pool_async( transport: str = "grpc_asyncio", @@ -1570,6 +1670,9 @@ def test_get_specialist_pool_empty_call(): with mock.patch.object( type(client.transport.get_specialist_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1595,6 +1698,9 @@ def test_get_specialist_pool_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_specialist_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_specialist_pool(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1603,6 +1709,45 @@ def test_get_specialist_pool_non_empty_request_with_auto_populated_field(): ) +def test_get_specialist_pool_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_specialist_pool in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_specialist_pool + ] = mock_rpc + request = {} + client.get_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_specialist_pool_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1633,6 +1778,52 @@ async def test_get_specialist_pool_empty_call_async(): assert args[0] == specialist_pool_service.GetSpecialistPoolRequest() +@pytest.mark.asyncio +async def test_get_specialist_pool_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_specialist_pool + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_specialist_pool + ] = mock_object + + request = {} + await client.get_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_specialist_pool_async( transport: str = "grpc_asyncio", @@ -1886,6 +2077,9 @@ def test_list_specialist_pools_empty_call(): with mock.patch.object( type(client.transport.list_specialist_pools), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_specialist_pools() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1912,6 +2106,9 @@ def test_list_specialist_pools_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_specialist_pools), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_specialist_pools(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1921,6 +2118,46 @@ def test_list_specialist_pools_non_empty_request_with_auto_populated_field(): ) +def test_list_specialist_pools_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_specialist_pools + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_specialist_pools + ] = mock_rpc + request = {} + client.list_specialist_pools(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_specialist_pools(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_specialist_pools_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1946,6 +2183,52 @@ async def test_list_specialist_pools_empty_call_async(): assert args[0] == specialist_pool_service.ListSpecialistPoolsRequest() +@pytest.mark.asyncio +async def test_list_specialist_pools_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_specialist_pools + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_specialist_pools + ] = mock_object + + request = {} + await client.list_specialist_pools(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_specialist_pools(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_specialist_pools_async( transport: str = "grpc_asyncio", @@ -2384,6 +2667,9 @@ def test_delete_specialist_pool_empty_call(): with mock.patch.object( type(client.transport.delete_specialist_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2409,6 +2695,9 @@ def test_delete_specialist_pool_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_specialist_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_specialist_pool(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2417,6 +2706,50 @@ def test_delete_specialist_pool_non_empty_request_with_auto_populated_field(): ) +def test_delete_specialist_pool_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_specialist_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_specialist_pool + ] = mock_rpc + request = {} + client.delete_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_specialist_pool_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2440,6 +2773,56 @@ async def test_delete_specialist_pool_empty_call_async(): assert args[0] == specialist_pool_service.DeleteSpecialistPoolRequest() +@pytest.mark.asyncio +async def test_delete_specialist_pool_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_specialist_pool + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_specialist_pool + ] = mock_object + + request = {} + await client.delete_specialist_pool(request) + + # Establish that the 
underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_specialist_pool_async( transport: str = "grpc_asyncio", @@ -2677,6 +3060,9 @@ def test_update_specialist_pool_empty_call(): with mock.patch.object( type(client.transport.update_specialist_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_specialist_pool() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2700,12 +3086,59 @@ def test_update_specialist_pool_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.update_specialist_pool), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_specialist_pool(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == specialist_pool_service.UpdateSpecialistPoolRequest() +def test_update_specialist_pool_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_specialist_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_specialist_pool + ] = mock_rpc + request = {} + client.update_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_specialist_pool_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2729,6 +3162,56 @@ async def test_update_specialist_pool_empty_call_async(): assert args[0] == specialist_pool_service.UpdateSpecialistPoolRequest() +@pytest.mark.asyncio +async def test_update_specialist_pool_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_specialist_pool + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_specialist_pool + ] = mock_object + + request = {} + await client.update_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_specialist_pool_async( transport: str = "grpc_asyncio", @@ -3050,6 +3533,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_specialist_pool_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_specialist_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_specialist_pool + ] = mock_rpc + + request = {} + client.create_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_specialist_pool_rest_required_fields( request_type=specialist_pool_service.CreateSpecialistPoolRequest, ): @@ -3338,6 +3866,46 @@ def test_get_specialist_pool_rest(request_type): assert response.specialist_worker_emails == ["specialist_worker_emails_value"] +def test_get_specialist_pool_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_specialist_pool in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_specialist_pool + ] = mock_rpc + + request = {} + client.get_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_specialist_pool_rest_required_fields( request_type=specialist_pool_service.GetSpecialistPoolRequest, ): @@ -3612,6 +4180,47 @@ def test_list_specialist_pools_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_specialist_pools_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_specialist_pools + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_specialist_pools + ] = mock_rpc + + request = {} + client.list_specialist_pools(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_specialist_pools(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_specialist_pools_rest_required_fields( request_type=specialist_pool_service.ListSpecialistPoolsRequest, ): @@ -3958,6 +4567,51 @@ def test_delete_specialist_pool_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_specialist_pool_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_specialist_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_specialist_pool + ] = mock_rpc + + request = {} + client.delete_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_specialist_pool_rest_required_fields( request_type=specialist_pool_service.DeleteSpecialistPoolRequest, ): @@ -4314,6 +4968,51 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_specialist_pool_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = SpecialistPoolServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_specialist_pool + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_specialist_pool + ] = mock_rpc + + request = {} + client.update_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_specialist_pool(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_specialist_pool_rest_required_fields( request_type=specialist_pool_service.UpdateSpecialistPoolRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_tensorboard_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_tensorboard_service.py index 43cd132a16..1a1159a0e7 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_tensorboard_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_tensorboard_service.py @@ -1239,6 +1239,9 @@ def test_create_tensorboard_empty_call(): with mock.patch.object( type(client.transport.create_tensorboard), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_tensorboard() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1264,6 +1267,9 @@ def test_create_tensorboard_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_tensorboard), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_tensorboard(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1272,6 +1278,49 @@ def test_create_tensorboard_non_empty_request_with_auto_populated_field(): ) +def test_create_tensorboard_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_tensorboard in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_tensorboard + ] = mock_rpc + request = {} + client.create_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_tensorboard_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1295,6 +1344,56 @@ async def test_create_tensorboard_empty_call_async(): assert args[0] == tensorboard_service.CreateTensorboardRequest() +@pytest.mark.asyncio +async def test_create_tensorboard_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_tensorboard + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_tensorboard + ] = mock_object + + request = {} + await client.create_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_tensorboard_async( transport: str = "grpc_asyncio", @@ -1553,6 +1652,9 @@ def test_get_tensorboard_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_tensorboard), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_tensorboard() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1576,6 +1678,9 @@ def test_get_tensorboard_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_tensorboard), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_tensorboard(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1584,6 +1689,41 @@ def test_get_tensorboard_non_empty_request_with_auto_populated_field(): ) +def test_get_tensorboard_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_tensorboard in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_tensorboard] = mock_rpc + request = {} + client.get_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_tensorboard_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1613,6 +1753,52 @@ async def test_get_tensorboard_empty_call_async(): assert args[0] == tensorboard_service.GetTensorboardRequest() +@pytest.mark.asyncio +async def test_get_tensorboard_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_tensorboard + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_tensorboard + ] = mock_object + + request = {} + await client.get_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_tensorboard_async( transport: str = "grpc_asyncio", @@ -1855,6 +2041,9 @@ def test_update_tensorboard_empty_call(): with mock.patch.object( type(client.transport.update_tensorboard), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_tensorboard() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1878,12 +2067,58 @@ def test_update_tensorboard_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.update_tensorboard), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_tensorboard(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == tensorboard_service.UpdateTensorboardRequest() +def test_update_tensorboard_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_tensorboard in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.update_tensorboard + ] = mock_rpc + request = {} + client.update_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_tensorboard_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1907,6 +2142,56 @@ async def test_update_tensorboard_empty_call_async(): assert args[0] == tensorboard_service.UpdateTensorboardRequest() +@pytest.mark.asyncio +async def test_update_tensorboard_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_tensorboard + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_tensorboard + ] = mock_object + + request = {} + await client.update_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_tensorboard_async( transport: str = "grpc_asyncio", @@ -2157,6 +2442,9 @@ def test_list_tensorboards_empty_call(): with mock.patch.object( type(client.transport.list_tensorboards), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_tensorboards() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2185,6 +2473,9 @@ def test_list_tensorboards_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_tensorboards), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_tensorboards(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2196,6 +2487,43 @@ def test_list_tensorboards_non_empty_request_with_auto_populated_field(): ) +def test_list_tensorboards_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_tensorboards in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_tensorboards + ] = mock_rpc + request = {} + client.list_tensorboards(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_tensorboards(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_tensorboards_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2221,6 +2549,52 @@ async def test_list_tensorboards_empty_call_async(): assert args[0] == tensorboard_service.ListTensorboardsRequest() +@pytest.mark.asyncio +async def test_list_tensorboards_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_tensorboards + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_tensorboards + ] = mock_object + + request = {} + await client.list_tensorboards(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_tensorboards(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_tensorboards_async( transport: str = "grpc_asyncio", @@ -2659,6 +3033,9 @@ def test_delete_tensorboard_empty_call(): with mock.patch.object( type(client.transport.delete_tensorboard), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_tensorboard() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2684,6 +3061,9 @@ def test_delete_tensorboard_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_tensorboard), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_tensorboard(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2692,6 +3072,49 @@ def test_delete_tensorboard_non_empty_request_with_auto_populated_field(): ) +def test_delete_tensorboard_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_tensorboard in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_tensorboard + ] = mock_rpc + request = {} + client.delete_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_tensorboard_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2715,6 +3138,56 @@ async def test_delete_tensorboard_empty_call_async(): assert args[0] == tensorboard_service.DeleteTensorboardRequest() +@pytest.mark.asyncio +async def test_delete_tensorboard_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_tensorboard + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_tensorboard + ] = mock_object + + request = {} + await client.delete_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_tensorboard_async( transport: str = "grpc_asyncio", @@ -2952,6 +3425,9 @@ def test_read_tensorboard_usage_empty_call(): with mock.patch.object( type(client.transport.read_tensorboard_usage), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.read_tensorboard_usage() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2977,6 +3453,9 @@ def test_read_tensorboard_usage_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.read_tensorboard_usage), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.read_tensorboard_usage(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2985,6 +3464,46 @@ def test_read_tensorboard_usage_non_empty_request_with_auto_populated_field(): ) +def test_read_tensorboard_usage_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_tensorboard_usage + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_tensorboard_usage + ] = mock_rpc + request = {} + client.read_tensorboard_usage(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_tensorboard_usage(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_read_tensorboard_usage_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3008,6 +3527,52 @@ async def test_read_tensorboard_usage_empty_call_async(): assert args[0] == tensorboard_service.ReadTensorboardUsageRequest() +@pytest.mark.asyncio +async def test_read_tensorboard_usage_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.read_tensorboard_usage + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.read_tensorboard_usage + ] = mock_object + + request = {} + await client.read_tensorboard_usage(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.read_tensorboard_usage(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_read_tensorboard_usage_async( transport: str = "grpc_asyncio", @@ -3248,6 +3813,9 @@ def test_read_tensorboard_size_empty_call(): with mock.patch.object( type(client.transport.read_tensorboard_size), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.read_tensorboard_size() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3273,6 +3841,9 @@ def test_read_tensorboard_size_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.read_tensorboard_size), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.read_tensorboard_size(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3281,6 +3852,46 @@ def test_read_tensorboard_size_non_empty_request_with_auto_populated_field(): ) +def test_read_tensorboard_size_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_tensorboard_size + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.read_tensorboard_size + ] = mock_rpc + request = {} + client.read_tensorboard_size(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.read_tensorboard_size(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_read_tensorboard_size_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3306,6 +3917,52 @@ async def test_read_tensorboard_size_empty_call_async(): assert args[0] == tensorboard_service.ReadTensorboardSizeRequest() +@pytest.mark.asyncio +async def test_read_tensorboard_size_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.read_tensorboard_size + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.read_tensorboard_size + ] = mock_object + + request = {} + await client.read_tensorboard_size(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.read_tensorboard_size(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_read_tensorboard_size_async( transport: str = "grpc_asyncio", @@ -3557,6 +4214,9 @@ def test_create_tensorboard_experiment_empty_call(): with mock.patch.object( type(client.transport.create_tensorboard_experiment), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_tensorboard_experiment() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3583,6 +4243,9 @@ def test_create_tensorboard_experiment_non_empty_request_with_auto_populated_fie with mock.patch.object( type(client.transport.create_tensorboard_experiment), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_tensorboard_experiment(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3592,6 +4255,46 @@ def test_create_tensorboard_experiment_non_empty_request_with_auto_populated_fie ) +def test_create_tensorboard_experiment_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_tensorboard_experiment + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_tensorboard_experiment + ] = mock_rpc + request = {} + client.create_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_tensorboard_experiment_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3621,6 +4324,52 @@ async def test_create_tensorboard_experiment_empty_call_async(): assert args[0] == tensorboard_service.CreateTensorboardExperimentRequest() +@pytest.mark.asyncio +async def test_create_tensorboard_experiment_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_tensorboard_experiment + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_tensorboard_experiment + ] = mock_object + + request = {} + await client.create_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_tensorboard_experiment_async( transport: str = "grpc_asyncio", @@ -3908,6 +4657,9 @@ def test_get_tensorboard_experiment_empty_call(): with mock.patch.object( type(client.transport.get_tensorboard_experiment), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_tensorboard_experiment() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3933,6 +4685,9 @@ def test_get_tensorboard_experiment_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.get_tensorboard_experiment), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_tensorboard_experiment(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3941,6 +4696,46 @@ def test_get_tensorboard_experiment_non_empty_request_with_auto_populated_field( ) +def test_get_tensorboard_experiment_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_tensorboard_experiment + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_tensorboard_experiment + ] = mock_rpc + request = {} + client.get_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_tensorboard_experiment_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3970,6 +4765,52 @@ async def test_get_tensorboard_experiment_empty_call_async(): assert args[0] == tensorboard_service.GetTensorboardExperimentRequest() +@pytest.mark.asyncio +async def test_get_tensorboard_experiment_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_tensorboard_experiment + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_tensorboard_experiment + ] = mock_object + + request = {} + await client.get_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_tensorboard_experiment_async( transport: str = "grpc_asyncio", @@ -4229,6 +5070,9 @@ def test_update_tensorboard_experiment_empty_call(): with mock.patch.object( type(client.transport.update_tensorboard_experiment), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_tensorboard_experiment() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4252,12 +5096,55 @@ def test_update_tensorboard_experiment_non_empty_request_with_auto_populated_fie with mock.patch.object( type(client.transport.update_tensorboard_experiment), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_tensorboard_experiment(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == tensorboard_service.UpdateTensorboardExperimentRequest() +def test_update_tensorboard_experiment_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_tensorboard_experiment + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_tensorboard_experiment + ] = mock_rpc + request = {} + client.update_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_tensorboard_experiment_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4287,6 +5174,52 @@ async def test_update_tensorboard_experiment_empty_call_async(): assert args[0] == tensorboard_service.UpdateTensorboardExperimentRequest() +@pytest.mark.asyncio +async def test_update_tensorboard_experiment_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_tensorboard_experiment + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_tensorboard_experiment + ] = mock_object + + request = {} + await client.update_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_tensorboard_experiment_async( transport: str = "grpc_asyncio", @@ -4556,6 +5489,9 @@ def test_list_tensorboard_experiments_empty_call(): with mock.patch.object( type(client.transport.list_tensorboard_experiments), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_tensorboard_experiments() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4584,6 +5520,9 @@ def test_list_tensorboard_experiments_non_empty_request_with_auto_populated_fiel with mock.patch.object( type(client.transport.list_tensorboard_experiments), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_tensorboard_experiments(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4595,6 +5534,46 @@ def test_list_tensorboard_experiments_non_empty_request_with_auto_populated_fiel ) +def test_list_tensorboard_experiments_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_tensorboard_experiments + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_tensorboard_experiments + ] = mock_rpc + request = {} + client.list_tensorboard_experiments(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_tensorboard_experiments(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_tensorboard_experiments_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4620,6 +5599,52 @@ async def test_list_tensorboard_experiments_empty_call_async(): assert args[0] == tensorboard_service.ListTensorboardExperimentsRequest() +@pytest.mark.asyncio +async def test_list_tensorboard_experiments_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_tensorboard_experiments + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_tensorboard_experiments + ] = mock_object + + request = {} + await client.list_tensorboard_experiments(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_tensorboard_experiments(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_tensorboard_experiments_async( transport: str = "grpc_asyncio", @@ -5063,6 +6088,9 @@ def test_delete_tensorboard_experiment_empty_call(): with mock.patch.object( type(client.transport.delete_tensorboard_experiment), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_tensorboard_experiment() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5088,6 +6116,9 @@ def test_delete_tensorboard_experiment_non_empty_request_with_auto_populated_fie with mock.patch.object( type(client.transport.delete_tensorboard_experiment), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_tensorboard_experiment(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5096,6 +6127,50 @@ def test_delete_tensorboard_experiment_non_empty_request_with_auto_populated_fie ) +def test_delete_tensorboard_experiment_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_tensorboard_experiment + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_tensorboard_experiment + ] = mock_rpc + request = {} + client.delete_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_tensorboard_experiment_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5113,10 +6188,60 @@ async def test_delete_tensorboard_experiment_empty_call_async(): call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( operations_pb2.Operation(name="operations/spam") ) - response = await client.delete_tensorboard_experiment() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == tensorboard_service.DeleteTensorboardExperimentRequest() + response = await client.delete_tensorboard_experiment() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == tensorboard_service.DeleteTensorboardExperimentRequest() + + +@pytest.mark.asyncio +async def test_delete_tensorboard_experiment_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_tensorboard_experiment + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + 
+ mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_tensorboard_experiment + ] = mock_object + + request = {} + await client.delete_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 @pytest.mark.asyncio @@ -5365,6 +6490,9 @@ def test_create_tensorboard_run_empty_call(): with mock.patch.object( type(client.transport.create_tensorboard_run), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_tensorboard_run() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5391,6 +6519,9 @@ def test_create_tensorboard_run_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_tensorboard_run), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_tensorboard_run(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5400,6 +6531,46 @@ def test_create_tensorboard_run_non_empty_request_with_auto_populated_field(): ) +def test_create_tensorboard_run_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_tensorboard_run + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_tensorboard_run + ] = mock_rpc + request = {} + client.create_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_tensorboard_run_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5428,6 +6599,52 @@ async def test_create_tensorboard_run_empty_call_async(): assert args[0] == tensorboard_service.CreateTensorboardRunRequest() +@pytest.mark.asyncio +async def test_create_tensorboard_run_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_tensorboard_run + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_tensorboard_run + ] = mock_object + + request = {} + await client.create_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_tensorboard_run_async( transport: str = "grpc_asyncio", @@ -5694,6 +6911,9 @@ def test_batch_create_tensorboard_runs_empty_call(): with mock.patch.object( type(client.transport.batch_create_tensorboard_runs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_create_tensorboard_runs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5719,6 +6939,9 @@ def test_batch_create_tensorboard_runs_non_empty_request_with_auto_populated_fie with mock.patch.object( type(client.transport.batch_create_tensorboard_runs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.batch_create_tensorboard_runs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5727,6 +6950,46 @@ def test_batch_create_tensorboard_runs_non_empty_request_with_auto_populated_fie ) +def test_batch_create_tensorboard_runs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_create_tensorboard_runs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_create_tensorboard_runs + ] = mock_rpc + request = {} + client.batch_create_tensorboard_runs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.batch_create_tensorboard_runs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_batch_create_tensorboard_runs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5750,6 +7013,52 @@ async def test_batch_create_tensorboard_runs_empty_call_async(): assert args[0] == tensorboard_service.BatchCreateTensorboardRunsRequest() +@pytest.mark.asyncio +async def test_batch_create_tensorboard_runs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_create_tensorboard_runs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_create_tensorboard_runs + ] = mock_object + + request = {} + await client.batch_create_tensorboard_runs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.batch_create_tensorboard_runs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_batch_create_tensorboard_runs_async( transport: str = "grpc_asyncio", @@ -6018,6 +7327,9 @@ def test_get_tensorboard_run_empty_call(): with mock.patch.object( type(client.transport.get_tensorboard_run), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_tensorboard_run() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6043,6 +7355,9 @@ def test_get_tensorboard_run_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.get_tensorboard_run), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_tensorboard_run(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6051,6 +7366,45 @@ def test_get_tensorboard_run_non_empty_request_with_auto_populated_field(): ) +def test_get_tensorboard_run_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_tensorboard_run in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.get_tensorboard_run + ] = mock_rpc + request = {} + client.get_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_tensorboard_run_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6079,6 +7433,52 @@ async def test_get_tensorboard_run_empty_call_async(): assert args[0] == tensorboard_service.GetTensorboardRunRequest() +@pytest.mark.asyncio +async def test_get_tensorboard_run_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_tensorboard_run + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_tensorboard_run + ] = mock_object + + request = {} + await client.get_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_tensorboard_run_async( transport: str = "grpc_asyncio", @@ -6334,6 +7734,9 @@ def test_update_tensorboard_run_empty_call(): with mock.patch.object( type(client.transport.update_tensorboard_run), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_tensorboard_run() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6357,12 +7760,55 @@ def test_update_tensorboard_run_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.update_tensorboard_run), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_tensorboard_run(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == tensorboard_service.UpdateTensorboardRunRequest() +def test_update_tensorboard_run_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_tensorboard_run + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.update_tensorboard_run + ] = mock_rpc + request = {} + client.update_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_tensorboard_run_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6391,6 +7837,52 @@ async def test_update_tensorboard_run_empty_call_async(): assert args[0] == tensorboard_service.UpdateTensorboardRunRequest() +@pytest.mark.asyncio +async def test_update_tensorboard_run_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_tensorboard_run + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_tensorboard_run + ] = mock_object + + request = {} + await client.update_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_tensorboard_run_async( transport: str = "grpc_asyncio", @@ -6650,6 +8142,9 @@ def test_list_tensorboard_runs_empty_call(): with mock.patch.object( type(client.transport.list_tensorboard_runs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_tensorboard_runs() call.assert_called() _, args, _ = call.mock_calls[0] @@ -6678,6 +8173,9 @@ def test_list_tensorboard_runs_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_tensorboard_runs), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_tensorboard_runs(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -6689,6 +8187,46 @@ def test_list_tensorboard_runs_non_empty_request_with_auto_populated_field(): ) +def test_list_tensorboard_runs_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_tensorboard_runs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_tensorboard_runs + ] = mock_rpc + request = {} + client.list_tensorboard_runs(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_tensorboard_runs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_tensorboard_runs_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -6714,6 +8252,52 @@ async def test_list_tensorboard_runs_empty_call_async(): assert args[0] == tensorboard_service.ListTensorboardRunsRequest() +@pytest.mark.asyncio +async def test_list_tensorboard_runs_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_tensorboard_runs + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_tensorboard_runs + ] = mock_object + + request = {} + await client.list_tensorboard_runs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_tensorboard_runs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_tensorboard_runs_async( transport: str = "grpc_asyncio", @@ -7152,6 +8736,9 @@ def test_delete_tensorboard_run_empty_call(): with mock.patch.object( type(client.transport.delete_tensorboard_run), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_tensorboard_run() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7177,6 +8764,9 @@ def test_delete_tensorboard_run_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_tensorboard_run), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_tensorboard_run(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7185,6 +8775,50 @@ def test_delete_tensorboard_run_non_empty_request_with_auto_populated_field(): ) +def test_delete_tensorboard_run_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_tensorboard_run + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_tensorboard_run + ] = mock_rpc + request = {} + client.delete_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_tensorboard_run_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7208,6 +8842,56 @@ async def test_delete_tensorboard_run_empty_call_async(): assert args[0] == tensorboard_service.DeleteTensorboardRunRequest() +@pytest.mark.asyncio +async def test_delete_tensorboard_run_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_tensorboard_run + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_tensorboard_run + ] = mock_object + + request = {} + await client.delete_tensorboard_run(request) + + # Establish that the 
underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_tensorboard_run_async( transport: str = "grpc_asyncio", @@ -7449,6 +9133,9 @@ def test_batch_create_tensorboard_time_series_empty_call(): with mock.patch.object( type(client.transport.batch_create_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_create_tensorboard_time_series() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7474,6 +9161,9 @@ def test_batch_create_tensorboard_time_series_non_empty_request_with_auto_popula with mock.patch.object( type(client.transport.batch_create_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.batch_create_tensorboard_time_series(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7482,6 +9172,46 @@ def test_batch_create_tensorboard_time_series_non_empty_request_with_auto_popula ) +def test_batch_create_tensorboard_time_series_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_create_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_create_tensorboard_time_series + ] = mock_rpc + request = {} + client.batch_create_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.batch_create_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_batch_create_tensorboard_time_series_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7505,6 +9235,52 @@ async def test_batch_create_tensorboard_time_series_empty_call_async(): assert args[0] == tensorboard_service.BatchCreateTensorboardTimeSeriesRequest() +@pytest.mark.asyncio +async def test_batch_create_tensorboard_time_series_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_create_tensorboard_time_series + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_create_tensorboard_time_series + ] = mock_object + + request = {} + await client.batch_create_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.batch_create_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_batch_create_tensorboard_time_series_async( transport: str = "grpc_asyncio", @@ -7802,6 +9578,9 @@ def test_create_tensorboard_time_series_empty_call(): with mock.patch.object( type(client.transport.create_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_tensorboard_time_series() call.assert_called() _, args, _ = call.mock_calls[0] @@ -7828,6 +9607,9 @@ def test_create_tensorboard_time_series_non_empty_request_with_auto_populated_fi with mock.patch.object( type(client.transport.create_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_tensorboard_time_series(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -7837,6 +9619,46 @@ def test_create_tensorboard_time_series_non_empty_request_with_auto_populated_fi ) +def test_create_tensorboard_time_series_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_tensorboard_time_series + ] = mock_rpc + request = {} + client.create_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_tensorboard_time_series_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -7868,6 +9690,52 @@ async def test_create_tensorboard_time_series_empty_call_async(): assert args[0] == tensorboard_service.CreateTensorboardTimeSeriesRequest() +@pytest.mark.asyncio +async def test_create_tensorboard_time_series_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_tensorboard_time_series + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_tensorboard_time_series + ] = mock_object + + request = {} + await client.create_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_tensorboard_time_series_async( transport: str = "grpc_asyncio", @@ -8159,6 +10027,9 @@ def test_get_tensorboard_time_series_empty_call(): with mock.patch.object( type(client.transport.get_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_tensorboard_time_series() call.assert_called() _, args, _ = call.mock_calls[0] @@ -8184,6 +10055,9 @@ def test_get_tensorboard_time_series_non_empty_request_with_auto_populated_field with mock.patch.object( type(client.transport.get_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_tensorboard_time_series(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -8192,6 +10066,46 @@ def test_get_tensorboard_time_series_non_empty_request_with_auto_populated_field ) +def test_get_tensorboard_time_series_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_tensorboard_time_series + ] = mock_rpc + request = {} + client.get_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_tensorboard_time_series_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -8223,6 +10137,52 @@ async def test_get_tensorboard_time_series_empty_call_async(): assert args[0] == tensorboard_service.GetTensorboardTimeSeriesRequest() +@pytest.mark.asyncio +async def test_get_tensorboard_time_series_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_tensorboard_time_series + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_tensorboard_time_series + ] = mock_object + + request = {} + await client.get_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_tensorboard_time_series_async( transport: str = "grpc_asyncio", @@ -8496,6 +10456,9 @@ def test_update_tensorboard_time_series_empty_call(): with mock.patch.object( type(client.transport.update_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.update_tensorboard_time_series() call.assert_called() _, args, _ = call.mock_calls[0] @@ -8519,12 +10482,55 @@ def test_update_tensorboard_time_series_non_empty_request_with_auto_populated_fi with mock.patch.object( type(client.transport.update_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.update_tensorboard_time_series(request=request) call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == tensorboard_service.UpdateTensorboardTimeSeriesRequest() +def test_update_tensorboard_time_series_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_tensorboard_time_series + ] = mock_rpc + request = {} + client.update_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_update_tensorboard_time_series_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -8556,6 +10562,52 @@ async def test_update_tensorboard_time_series_empty_call_async(): assert args[0] == tensorboard_service.UpdateTensorboardTimeSeriesRequest() +@pytest.mark.asyncio +async def test_update_tensorboard_time_series_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_tensorboard_time_series + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.update_tensorboard_time_series + ] = mock_object + + request = {} + await client.update_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.update_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_update_tensorboard_time_series_async( transport: str = "grpc_asyncio", @@ -8832,6 +10884,9 @@ def test_list_tensorboard_time_series_empty_call(): with mock.patch.object( type(client.transport.list_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_tensorboard_time_series() call.assert_called() _, args, _ = call.mock_calls[0] @@ -8860,6 +10915,9 @@ def test_list_tensorboard_time_series_non_empty_request_with_auto_populated_fiel with mock.patch.object( type(client.transport.list_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_tensorboard_time_series(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -8871,6 +10929,46 @@ def test_list_tensorboard_time_series_non_empty_request_with_auto_populated_fiel ) +def test_list_tensorboard_time_series_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_tensorboard_time_series + ] = mock_rpc + request = {} + client.list_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_tensorboard_time_series_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -8896,6 +10994,52 @@ async def test_list_tensorboard_time_series_empty_call_async(): assert args[0] == tensorboard_service.ListTensorboardTimeSeriesRequest() +@pytest.mark.asyncio +async def test_list_tensorboard_time_series_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_tensorboard_time_series + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_tensorboard_time_series + ] = mock_object + + request = {} + await client.list_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_tensorboard_time_series_async( transport: str = "grpc_asyncio", @@ -9340,6 +11484,9 @@ def test_delete_tensorboard_time_series_empty_call(): with mock.patch.object( type(client.transport.delete_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_tensorboard_time_series() call.assert_called() _, args, _ = call.mock_calls[0] @@ -9365,6 +11512,9 @@ def test_delete_tensorboard_time_series_non_empty_request_with_auto_populated_fi with mock.patch.object( type(client.transport.delete_tensorboard_time_series), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_tensorboard_time_series(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -9373,6 +11523,50 @@ def test_delete_tensorboard_time_series_non_empty_request_with_auto_populated_fi ) +def test_delete_tensorboard_time_series_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_tensorboard_time_series + ] = mock_rpc + request = {} + client.delete_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_tensorboard_time_series_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -9396,6 +11590,56 @@ async def test_delete_tensorboard_time_series_empty_call_async(): assert args[0] == tensorboard_service.DeleteTensorboardTimeSeriesRequest() +@pytest.mark.asyncio +async def test_delete_tensorboard_time_series_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_tensorboard_time_series + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_tensorboard_time_series + ] = mock_object + + request = {} + await client.delete_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_tensorboard_time_series_async( transport: str = "grpc_asyncio", @@ -9637,6 +11881,9 @@ def test_batch_read_tensorboard_time_series_data_empty_call(): with mock.patch.object( type(client.transport.batch_read_tensorboard_time_series_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.batch_read_tensorboard_time_series_data() call.assert_called() _, args, _ = call.mock_calls[0] @@ -9664,6 +11911,9 @@ def test_batch_read_tensorboard_time_series_data_non_empty_request_with_auto_pop with mock.patch.object( type(client.transport.batch_read_tensorboard_time_series_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.batch_read_tensorboard_time_series_data(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -9672,6 +11922,46 @@ def test_batch_read_tensorboard_time_series_data_non_empty_request_with_auto_pop ) +def test_batch_read_tensorboard_time_series_data_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_read_tensorboard_time_series_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_read_tensorboard_time_series_data + ] = mock_rpc + request = {} + client.batch_read_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.batch_read_tensorboard_time_series_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_batch_read_tensorboard_time_series_data_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -9697,6 +11987,52 @@ async def test_batch_read_tensorboard_time_series_data_empty_call_async(): ) +@pytest.mark.asyncio +async def test_batch_read_tensorboard_time_series_data_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_read_tensorboard_time_series_data + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_read_tensorboard_time_series_data + ] = mock_object + + request = {} + await client.batch_read_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.batch_read_tensorboard_time_series_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_batch_read_tensorboard_time_series_data_async( transport: str = "grpc_asyncio", @@ -9944,6 +12280,9 @@ def test_read_tensorboard_time_series_data_empty_call(): with mock.patch.object( type(client.transport.read_tensorboard_time_series_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.read_tensorboard_time_series_data() call.assert_called() _, args, _ = call.mock_calls[0] @@ -9970,6 +12309,9 @@ def test_read_tensorboard_time_series_data_non_empty_request_with_auto_populated with mock.patch.object( type(client.transport.read_tensorboard_time_series_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.read_tensorboard_time_series_data(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -9979,6 +12321,46 @@ def test_read_tensorboard_time_series_data_non_empty_request_with_auto_populated ) +def test_read_tensorboard_time_series_data_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_tensorboard_time_series_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_tensorboard_time_series_data + ] = mock_rpc + request = {} + client.read_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_tensorboard_time_series_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_read_tensorboard_time_series_data_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -10002,6 +12384,52 @@ async def test_read_tensorboard_time_series_data_empty_call_async(): assert args[0] == tensorboard_service.ReadTensorboardTimeSeriesDataRequest() +@pytest.mark.asyncio +async def test_read_tensorboard_time_series_data_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.read_tensorboard_time_series_data + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.read_tensorboard_time_series_data + ] = mock_object + + request = {} + await client.read_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.read_tensorboard_time_series_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_read_tensorboard_time_series_data_async( transport: str = "grpc_asyncio", @@ -10244,6 +12672,9 @@ def test_read_tensorboard_blob_data_empty_call(): with mock.patch.object( type(client.transport.read_tensorboard_blob_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.read_tensorboard_blob_data() call.assert_called() _, args, _ = call.mock_calls[0] @@ -10269,6 +12700,9 @@ def test_read_tensorboard_blob_data_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.read_tensorboard_blob_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.read_tensorboard_blob_data(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -10277,6 +12711,46 @@ def test_read_tensorboard_blob_data_non_empty_request_with_auto_populated_field( ) +def test_read_tensorboard_blob_data_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_tensorboard_blob_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_tensorboard_blob_data + ] = mock_rpc + request = {} + client.read_tensorboard_blob_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_tensorboard_blob_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_read_tensorboard_blob_data_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -10301,6 +12775,52 @@ async def test_read_tensorboard_blob_data_empty_call_async(): assert args[0] == tensorboard_service.ReadTensorboardBlobDataRequest() +@pytest.mark.asyncio +async def test_read_tensorboard_blob_data_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.read_tensorboard_blob_data + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.read_tensorboard_blob_data + ] = mock_object + + request = {} + await client.read_tensorboard_blob_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.read_tensorboard_blob_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_read_tensorboard_blob_data_async( transport: str = "grpc_asyncio", @@ -10547,6 +13067,9 @@ def test_write_tensorboard_experiment_data_empty_call(): with mock.patch.object( type(client.transport.write_tensorboard_experiment_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.write_tensorboard_experiment_data() call.assert_called() _, args, _ = call.mock_calls[0] @@ -10572,6 +13095,9 @@ def test_write_tensorboard_experiment_data_non_empty_request_with_auto_populated with mock.patch.object( type(client.transport.write_tensorboard_experiment_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.write_tensorboard_experiment_data(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -10580,6 +13106,46 @@ def test_write_tensorboard_experiment_data_non_empty_request_with_auto_populated ) +def test_write_tensorboard_experiment_data_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.write_tensorboard_experiment_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.write_tensorboard_experiment_data + ] = mock_rpc + request = {} + client.write_tensorboard_experiment_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.write_tensorboard_experiment_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_write_tensorboard_experiment_data_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -10603,6 +13169,52 @@ async def test_write_tensorboard_experiment_data_empty_call_async(): assert args[0] == tensorboard_service.WriteTensorboardExperimentDataRequest() +@pytest.mark.asyncio +async def test_write_tensorboard_experiment_data_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.write_tensorboard_experiment_data + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.write_tensorboard_experiment_data + ] = mock_object + + request = {} + await client.write_tensorboard_experiment_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.write_tensorboard_experiment_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_write_tensorboard_experiment_data_async( transport: str = "grpc_asyncio", @@ -10876,6 +13488,9 @@ def test_write_tensorboard_run_data_empty_call(): with mock.patch.object( type(client.transport.write_tensorboard_run_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.write_tensorboard_run_data() call.assert_called() _, args, _ = call.mock_calls[0] @@ -10901,6 +13516,9 @@ def test_write_tensorboard_run_data_non_empty_request_with_auto_populated_field( with mock.patch.object( type(client.transport.write_tensorboard_run_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.write_tensorboard_run_data(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -10909,6 +13527,46 @@ def test_write_tensorboard_run_data_non_empty_request_with_auto_populated_field( ) +def test_write_tensorboard_run_data_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.write_tensorboard_run_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.write_tensorboard_run_data + ] = mock_rpc + request = {} + client.write_tensorboard_run_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.write_tensorboard_run_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_write_tensorboard_run_data_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -10932,6 +13590,52 @@ async def test_write_tensorboard_run_data_empty_call_async(): assert args[0] == tensorboard_service.WriteTensorboardRunDataRequest() +@pytest.mark.asyncio +async def test_write_tensorboard_run_data_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.write_tensorboard_run_data + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.write_tensorboard_run_data + ] = mock_object + + request = {} + await client.write_tensorboard_run_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.write_tensorboard_run_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_write_tensorboard_run_data_async( transport: str = "grpc_asyncio", @@ -11206,6 +13910,9 @@ def test_export_tensorboard_time_series_data_empty_call(): with mock.patch.object( type(client.transport.export_tensorboard_time_series_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.export_tensorboard_time_series_data() call.assert_called() _, args, _ = call.mock_calls[0] @@ -11234,6 +13941,9 @@ def test_export_tensorboard_time_series_data_non_empty_request_with_auto_populat with mock.patch.object( type(client.transport.export_tensorboard_time_series_data), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.export_tensorboard_time_series_data(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -11245,6 +13955,46 @@ def test_export_tensorboard_time_series_data_non_empty_request_with_auto_populat ) +def test_export_tensorboard_time_series_data_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.export_tensorboard_time_series_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.export_tensorboard_time_series_data + ] = mock_rpc + request = {} + client.export_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.export_tensorboard_time_series_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_export_tensorboard_time_series_data_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -11270,6 +14020,52 @@ async def test_export_tensorboard_time_series_data_empty_call_async(): assert args[0] == tensorboard_service.ExportTensorboardTimeSeriesDataRequest() +@pytest.mark.asyncio +async def test_export_tensorboard_time_series_data_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TensorboardServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.export_tensorboard_time_series_data + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.export_tensorboard_time_series_data + ] = mock_object + + request = {} + await client.export_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.export_tensorboard_time_series_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_export_tensorboard_time_series_data_async( transport: str = "grpc_asyncio", @@ -11786,6 +14582,50 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_tensorboard_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_tensorboard in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_tensorboard + ] = mock_rpc + + request = {} + client.create_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_tensorboard_rest_required_fields( request_type=tensorboard_service.CreateTensorboardRequest, ): @@ -12073,6 +14913,42 @@ def test_get_tensorboard_rest(request_type): assert response.is_default is True +def test_get_tensorboard_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_tensorboard in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_tensorboard] = mock_rpc + + request = {} + client.get_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_tensorboard_rest_required_fields( request_type=tensorboard_service.GetTensorboardRequest, ): @@ -12421,6 +15297,50 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_update_tensorboard_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_tensorboard in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_tensorboard + ] = mock_rpc + + request = {} + client.update_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_tensorboard_rest_required_fields( request_type=tensorboard_service.UpdateTensorboardRequest, ): @@ -12701,6 +15621,44 @@ def test_list_tensorboards_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_tensorboards_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_tensorboards in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_tensorboards + ] = mock_rpc + + request = {} + client.list_tensorboards(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_tensorboards(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_tensorboards_rest_required_fields( request_type=tensorboard_service.ListTensorboardsRequest, ): @@ -13043,6 +16001,50 @@ def test_delete_tensorboard_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_tensorboard_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_tensorboard in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_tensorboard + ] = mock_rpc + + request = {} + client.delete_tensorboard(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_tensorboard(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_tensorboard_rest_required_fields( request_type=tensorboard_service.DeleteTensorboardRequest, ): @@ -13308,6 +16310,47 @@ def test_read_tensorboard_usage_rest(request_type): assert isinstance(response, tensorboard_service.ReadTensorboardUsageResponse) +def test_read_tensorboard_usage_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_tensorboard_usage + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_tensorboard_usage + ] = mock_rpc + + request = {} + client.read_tensorboard_usage(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_tensorboard_usage(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_read_tensorboard_usage_rest_required_fields( request_type=tensorboard_service.ReadTensorboardUsageRequest, ): @@ -13586,6 +16629,47 @@ def test_read_tensorboard_size_rest(request_type): assert response.storage_size_byte == 1826 +def test_read_tensorboard_size_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_tensorboard_size + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_tensorboard_size + ] = mock_rpc + + request = {} + client.read_tensorboard_size(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_tensorboard_size(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_read_tensorboard_size_rest_required_fields( request_type=tensorboard_service.ReadTensorboardSizeRequest, ): @@ -13950,6 +17034,47 @@ def get_message_fields(field): assert response.source == "source_value" +def test_create_tensorboard_experiment_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_tensorboard_experiment + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_tensorboard_experiment + ] = mock_rpc + + request = {} + client.create_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_tensorboard_experiment_rest_required_fields( request_type=tensorboard_service.CreateTensorboardExperimentRequest, ): @@ -14275,6 +17400,47 @@ def test_get_tensorboard_experiment_rest(request_type): assert response.source == "source_value" +def test_get_tensorboard_experiment_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_tensorboard_experiment + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_tensorboard_experiment + ] = mock_rpc + + request = {} + client.get_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_tensorboard_experiment_rest_required_fields( request_type=tensorboard_service.GetTensorboardExperimentRequest, ): @@ -14642,6 +17808,47 @@ def get_message_fields(field): assert response.source == "source_value" +def test_update_tensorboard_experiment_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_tensorboard_experiment + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_tensorboard_experiment + ] = mock_rpc + + request = {} + client.update_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_tensorboard_experiment_rest_required_fields( request_type=tensorboard_service.UpdateTensorboardExperimentRequest, ): @@ -14940,6 +18147,47 @@ def test_list_tensorboard_experiments_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_tensorboard_experiments_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_tensorboard_experiments + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_tensorboard_experiments + ] = mock_rpc + + request = {} + client.list_tensorboard_experiments(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_tensorboard_experiments(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_tensorboard_experiments_rest_required_fields( request_type=tensorboard_service.ListTensorboardExperimentsRequest, ): @@ -15297,6 +18545,51 @@ def test_delete_tensorboard_experiment_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_tensorboard_experiment_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_tensorboard_experiment + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_tensorboard_experiment + ] = mock_rpc + + request = {} + client.delete_tensorboard_experiment(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_tensorboard_experiment(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_tensorboard_experiment_rest_required_fields( request_type=tensorboard_service.DeleteTensorboardExperimentRequest, ): @@ -15656,6 +18949,47 @@ def get_message_fields(field): assert response.etag == "etag_value" +def test_create_tensorboard_run_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_tensorboard_run + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_tensorboard_run + ] = mock_rpc + + request = {} + client.create_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_tensorboard_run_rest_required_fields( request_type=tensorboard_service.CreateTensorboardRunRequest, ): @@ -15957,6 +19291,47 @@ def test_batch_create_tensorboard_runs_rest(request_type): assert isinstance(response, tensorboard_service.BatchCreateTensorboardRunsResponse) +def test_batch_create_tensorboard_runs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_create_tensorboard_runs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_create_tensorboard_runs + ] = mock_rpc + + request = {} + client.batch_create_tensorboard_runs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.batch_create_tensorboard_runs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_batch_create_tensorboard_runs_rest_required_fields( request_type=tensorboard_service.BatchCreateTensorboardRunsRequest, ): @@ -16262,6 +19637,46 @@ def test_get_tensorboard_run_rest(request_type): assert response.etag == "etag_value" +def test_get_tensorboard_run_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_tensorboard_run in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_tensorboard_run + ] = mock_rpc + + request = {} + client.get_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_tensorboard_run_rest_required_fields( request_type=tensorboard_service.GetTensorboardRunRequest, ): @@ -16621,6 +20036,47 @@ def get_message_fields(field): assert response.etag == "etag_value" +def test_update_tensorboard_run_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_tensorboard_run + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_tensorboard_run + ] = mock_rpc + + request = {} + client.update_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_tensorboard_run_rest_required_fields( request_type=tensorboard_service.UpdateTensorboardRunRequest, ): @@ -16907,6 +20363,47 @@ def test_list_tensorboard_runs_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_tensorboard_runs_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_tensorboard_runs + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_tensorboard_runs + ] = mock_rpc + + request = {} + client.list_tensorboard_runs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_tensorboard_runs(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_tensorboard_runs_rest_required_fields( request_type=tensorboard_service.ListTensorboardRunsRequest, ): @@ -17259,6 +20756,51 @@ def test_delete_tensorboard_run_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_tensorboard_run_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_tensorboard_run + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_tensorboard_run + ] = mock_rpc + + request = {} + client.delete_tensorboard_run(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_tensorboard_run(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_tensorboard_run_rest_required_fields( request_type=tensorboard_service.DeleteTensorboardRunRequest, ): @@ -17531,6 +21073,47 @@ def test_batch_create_tensorboard_time_series_rest(request_type): ) +def test_batch_create_tensorboard_time_series_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_create_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_create_tensorboard_time_series + ] = mock_rpc + + request = {} + client.batch_create_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.batch_create_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_batch_create_tensorboard_time_series_rest_required_fields( request_type=tensorboard_service.BatchCreateTensorboardTimeSeriesRequest, ): @@ -17944,6 +21527,47 @@ def get_message_fields(field): assert response.plugin_data == b"plugin_data_blob" +def test_create_tensorboard_time_series_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_tensorboard_time_series + ] = mock_rpc + + request = {} + client.create_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_tensorboard_time_series_rest_required_fields( request_type=tensorboard_service.CreateTensorboardTimeSeriesRequest, ): @@ -18260,6 +21884,47 @@ def test_get_tensorboard_time_series_rest(request_type): assert response.plugin_data == b"plugin_data_blob" +def test_get_tensorboard_time_series_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_tensorboard_time_series + ] = mock_rpc + + request = {} + client.get_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_tensorboard_time_series_rest_required_fields( request_type=tensorboard_service.GetTensorboardTimeSeriesRequest, ): @@ -18644,6 +22309,47 @@ def get_message_fields(field): assert response.plugin_data == b"plugin_data_blob" +def test_update_tensorboard_time_series_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_tensorboard_time_series + ] = mock_rpc + + request = {} + client.update_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.update_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_update_tensorboard_time_series_rest_required_fields( request_type=tensorboard_service.UpdateTensorboardTimeSeriesRequest, ): @@ -18946,6 +22652,47 @@ def test_list_tensorboard_time_series_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_tensorboard_time_series_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_tensorboard_time_series + ] = mock_rpc + + request = {} + client.list_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_tensorboard_time_series_rest_required_fields( request_type=tensorboard_service.ListTensorboardTimeSeriesRequest, ): @@ -19306,6 +23053,51 @@ def test_delete_tensorboard_time_series_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_tensorboard_time_series_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_tensorboard_time_series + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_tensorboard_time_series + ] = mock_rpc + + request = {} + client.delete_tensorboard_time_series(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_tensorboard_time_series(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_tensorboard_time_series_rest_required_fields( request_type=tensorboard_service.DeleteTensorboardTimeSeriesRequest, ): @@ -19584,6 +23376,47 @@ def test_batch_read_tensorboard_time_series_data_rest(request_type): ) +def test_batch_read_tensorboard_time_series_data_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_read_tensorboard_time_series_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_read_tensorboard_time_series_data + ] = mock_rpc + + request = {} + client.batch_read_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.batch_read_tensorboard_time_series_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_batch_read_tensorboard_time_series_data_rest_required_fields( request_type=tensorboard_service.BatchReadTensorboardTimeSeriesDataRequest, ): @@ -19903,6 +23736,47 @@ def test_read_tensorboard_time_series_data_rest(request_type): ) +def test_read_tensorboard_time_series_data_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_tensorboard_time_series_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_tensorboard_time_series_data + ] = mock_rpc + + request = {} + client.read_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_tensorboard_time_series_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_read_tensorboard_time_series_data_rest_required_fields( request_type=tensorboard_service.ReadTensorboardTimeSeriesDataRequest, ): @@ -20210,6 +24084,47 @@ def test_read_tensorboard_blob_data_rest(request_type): assert isinstance(response, tensorboard_service.ReadTensorboardBlobDataResponse) +def test_read_tensorboard_blob_data_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.read_tensorboard_blob_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.read_tensorboard_blob_data + ] = mock_rpc + + request = {} + client.read_tensorboard_blob_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.read_tensorboard_blob_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_read_tensorboard_blob_data_rest_required_fields( request_type=tensorboard_service.ReadTensorboardBlobDataRequest, ): @@ -20500,6 +24415,47 @@ def test_write_tensorboard_experiment_data_rest(request_type): ) +def test_write_tensorboard_experiment_data_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.write_tensorboard_experiment_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.write_tensorboard_experiment_data + ] = mock_rpc + + request = {} + client.write_tensorboard_experiment_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.write_tensorboard_experiment_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_write_tensorboard_experiment_data_rest_required_fields( request_type=tensorboard_service.WriteTensorboardExperimentDataRequest, ): @@ -20806,6 +24762,47 @@ def test_write_tensorboard_run_data_rest(request_type): assert isinstance(response, tensorboard_service.WriteTensorboardRunDataResponse) +def test_write_tensorboard_run_data_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.write_tensorboard_run_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.write_tensorboard_run_data + ] = mock_rpc + + request = {} + client.write_tensorboard_run_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.write_tensorboard_run_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_write_tensorboard_run_data_rest_required_fields( request_type=tensorboard_service.WriteTensorboardRunDataRequest, ): @@ -21107,6 +25104,47 @@ def test_export_tensorboard_time_series_data_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_export_tensorboard_time_series_data_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TensorboardServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.export_tensorboard_time_series_data + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.export_tensorboard_time_series_data + ] = mock_rpc + + request = {} + client.export_tensorboard_time_series_data(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.export_tensorboard_time_series_data(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_export_tensorboard_time_series_data_rest_required_fields( request_type=tensorboard_service.ExportTensorboardTimeSeriesDataRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py index 4e4bb0feae..c2091a1cfe 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py @@ -1253,6 +1253,9 @@ def test_create_rag_corpus_empty_call(): with mock.patch.object( type(client.transport.create_rag_corpus), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_rag_corpus() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1278,6 +1281,9 @@ def test_create_rag_corpus_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.create_rag_corpus), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_rag_corpus(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1286,6 +1292,47 @@ def test_create_rag_corpus_non_empty_request_with_auto_populated_field(): ) +def test_create_rag_corpus_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_rag_corpus in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_rag_corpus + ] = mock_rpc + request = {} + client.create_rag_corpus(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_rag_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_rag_corpus_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1309,6 +1356,56 @@ async def test_create_rag_corpus_empty_call_async(): assert args[0] == vertex_rag_data_service.CreateRagCorpusRequest() +@pytest.mark.asyncio +async def test_create_rag_corpus_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_rag_corpus + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_rag_corpus + ] = mock_object + + request = {} + await client.create_rag_corpus(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_rag_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_rag_corpus_async( transport: str = "grpc_asyncio", @@ -1559,6 +1656,9 @@ def test_get_rag_corpus_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_rag_corpus() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1582,6 +1682,9 @@ def test_get_rag_corpus_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_rag_corpus), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_rag_corpus(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1590,6 +1693,41 @@ def test_get_rag_corpus_non_empty_request_with_auto_populated_field(): ) +def test_get_rag_corpus_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_rag_corpus in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_rag_corpus] = mock_rpc + request = {} + client.get_rag_corpus(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_rag_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_rag_corpus_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1615,6 +1753,52 @@ async def test_get_rag_corpus_empty_call_async(): assert args[0] == vertex_rag_data_service.GetRagCorpusRequest() +@pytest.mark.asyncio +async def test_get_rag_corpus_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_rag_corpus + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_rag_corpus + ] = mock_object + + request = {} + await client.get_rag_corpus(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_rag_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_rag_corpus_async( transport: str = "grpc_asyncio", @@ -1848,6 +2032,9 @@ def test_list_rag_corpora_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_rag_corpora() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1872,6 +2059,9 @@ def test_list_rag_corpora_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_rag_corpora), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_rag_corpora(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1881,6 +2071,43 @@ def test_list_rag_corpora_non_empty_request_with_auto_populated_field(): ) +def test_list_rag_corpora_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_rag_corpora in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_rag_corpora + ] = mock_rpc + request = {} + client.list_rag_corpora(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_rag_corpora(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_rag_corpora_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1904,6 +2131,52 @@ async def test_list_rag_corpora_empty_call_async(): assert args[0] == vertex_rag_data_service.ListRagCorporaRequest() +@pytest.mark.asyncio +async def test_list_rag_corpora_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_rag_corpora + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_rag_corpora + ] = mock_object + + request = {} + await client.list_rag_corpora(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_rag_corpora(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_rag_corpora_async( transport: str = "grpc_asyncio", @@ -2324,6 +2597,9 @@ def test_delete_rag_corpus_empty_call(): with mock.patch.object( type(client.transport.delete_rag_corpus), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_rag_corpus() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2349,6 +2625,9 @@ def test_delete_rag_corpus_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.delete_rag_corpus), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_rag_corpus(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2357,6 +2636,47 @@ def test_delete_rag_corpus_non_empty_request_with_auto_populated_field(): ) +def test_delete_rag_corpus_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_rag_corpus in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.delete_rag_corpus + ] = mock_rpc + request = {} + client.delete_rag_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_rag_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_rag_corpus_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2380,6 +2700,56 @@ async def test_delete_rag_corpus_empty_call_async(): assert args[0] == vertex_rag_data_service.DeleteRagCorpusRequest() +@pytest.mark.asyncio +async def test_delete_rag_corpus_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_rag_corpus + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_rag_corpus + ] = mock_object + + request = {} + await client.delete_rag_corpus(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_rag_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_rag_corpus_async( transport: str = "grpc_asyncio", @@ -2613,6 +2983,9 @@ def test_upload_rag_file_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.upload_rag_file() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2636,6 +3009,9 @@ def test_upload_rag_file_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.upload_rag_file), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.upload_rag_file(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2644,6 +3020,41 @@ def test_upload_rag_file_non_empty_request_with_auto_populated_field(): ) +def test_upload_rag_file_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.upload_rag_file in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.upload_rag_file] = mock_rpc + request = {} + client.upload_rag_file(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.upload_rag_file(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_upload_rag_file_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2665,6 +3076,52 @@ async def test_upload_rag_file_empty_call_async(): assert args[0] == vertex_rag_data_service.UploadRagFileRequest() +@pytest.mark.asyncio +async def test_upload_rag_file_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.upload_rag_file + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.upload_rag_file + ] = mock_object + + request = {} + await client.upload_rag_file(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.upload_rag_file(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_upload_rag_file_async( transport: str = "grpc_asyncio", @@ -2940,6 +3397,9 @@ def test_import_rag_files_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.import_rag_files() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2963,6 +3423,9 @@ def test_import_rag_files_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.import_rag_files), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.import_rag_files(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2971,6 +3434,47 @@ def test_import_rag_files_non_empty_request_with_auto_populated_field(): ) +def test_import_rag_files_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.import_rag_files in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.import_rag_files + ] = mock_rpc + request = {} + client.import_rag_files(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.import_rag_files(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_import_rag_files_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2992,6 +3496,56 @@ async def test_import_rag_files_empty_call_async(): assert args[0] == vertex_rag_data_service.ImportRagFilesRequest() +@pytest.mark.asyncio +async def test_import_rag_files_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.import_rag_files + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.import_rag_files + ] = mock_object + + request = {} + await client.import_rag_files(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.import_rag_files(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_import_rag_files_async( transport: str = "grpc_asyncio", @@ -3250,6 +3804,9 @@ def test_get_rag_file_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_rag_file() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3273,6 +3830,9 @@ def test_get_rag_file_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_rag_file), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_rag_file(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3281,6 +3841,41 @@ def test_get_rag_file_non_empty_request_with_auto_populated_field(): ) +def test_get_rag_file_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_rag_file in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_rag_file] = mock_rpc + request = {} + client.get_rag_file(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_rag_file(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_rag_file_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3308,6 +3903,52 @@ async def test_get_rag_file_empty_call_async(): assert args[0] == vertex_rag_data_service.GetRagFileRequest() +@pytest.mark.asyncio +async def test_get_rag_file_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_rag_file + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_rag_file + ] = mock_object + + request = {} + await client.get_rag_file(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_rag_file(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_rag_file_async( transport: str = "grpc_asyncio", @@ -3547,6 +4188,9 @@ def test_list_rag_files_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_rag_files() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3571,6 +4215,9 @@ def test_list_rag_files_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_rag_files), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_rag_files(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3580,6 +4227,41 @@ def test_list_rag_files_non_empty_request_with_auto_populated_field(): ) +def test_list_rag_files_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_rag_files in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_rag_files] = mock_rpc + request = {} + client.list_rag_files(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_rag_files(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_rag_files_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3603,6 +4285,52 @@ async def test_list_rag_files_empty_call_async(): assert args[0] == vertex_rag_data_service.ListRagFilesRequest() +@pytest.mark.asyncio +async def test_list_rag_files_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_rag_files + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_rag_files + ] = mock_object + + request = {} + await client.list_rag_files(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_rag_files(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_rag_files_async( transport: str = "grpc_asyncio", @@ -4019,6 +4747,9 @@ def test_delete_rag_file_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_rag_file() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4042,6 +4773,9 @@ def test_delete_rag_file_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_rag_file), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_rag_file(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4050,6 +4784,45 @@ def test_delete_rag_file_non_empty_request_with_auto_populated_field(): ) +def test_delete_rag_file_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_rag_file in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_rag_file] = mock_rpc + request = {} + client.delete_rag_file(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_rag_file(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_rag_file_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4071,6 +4844,56 @@ async def test_delete_rag_file_empty_call_async(): assert args[0] == vertex_rag_data_service.DeleteRagFileRequest() +@pytest.mark.asyncio +async def test_delete_rag_file_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VertexRagDataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_rag_file + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_rag_file + ] = mock_object + + request = {} + await client.delete_rag_file(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.delete_rag_file(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_rag_file_async( transport: str = "grpc_asyncio", @@ -4362,6 +5185,48 @@ def get_message_fields(field): assert response.operation.name == "operations/spam" +def test_create_rag_corpus_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_rag_corpus in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_rag_corpus + ] = mock_rpc + + request = {} + client.create_rag_corpus(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_rag_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_rag_corpus_rest_required_fields( request_type=vertex_rag_data_service.CreateRagCorpusRequest, ): @@ -4641,6 +5506,42 @@ def test_get_rag_corpus_rest(request_type): assert response.description == "description_value" +def test_get_rag_corpus_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_rag_corpus in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_rag_corpus] = mock_rpc + + request = {} + client.get_rag_corpus(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_rag_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_rag_corpus_rest_required_fields( request_type=vertex_rag_data_service.GetRagCorpusRequest, ): @@ -4910,6 +5811,44 @@ def test_list_rag_corpora_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_rag_corpora_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_rag_corpora in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_rag_corpora + ] = mock_rpc + + request = {} + client.list_rag_corpora(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_rag_corpora(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_rag_corpora_rest_required_fields( request_type=vertex_rag_data_service.ListRagCorporaRequest, ): @@ -5248,6 +6187,48 @@ def test_delete_rag_corpus_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_rag_corpus_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_rag_corpus in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_rag_corpus + ] = mock_rpc + + request = {} + client.delete_rag_corpus(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_rag_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_rag_corpus_rest_required_fields( request_type=vertex_rag_data_service.DeleteRagCorpusRequest, ): @@ -5513,6 +6494,42 @@ def test_upload_rag_file_rest(request_type): assert isinstance(response, vertex_rag_data_service.UploadRagFileResponse) +def test_upload_rag_file_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.upload_rag_file in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.upload_rag_file] = mock_rpc + + request = {} + client.upload_rag_file(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.upload_rag_file(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_upload_rag_file_rest_required_fields( request_type=vertex_rag_data_service.UploadRagFileRequest, ): @@ -5807,6 +6824,48 @@ def test_import_rag_files_rest(request_type): assert response.operation.name == "operations/spam" +def test_import_rag_files_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.import_rag_files in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.import_rag_files + ] = mock_rpc + + request = {} + client.import_rag_files(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.import_rag_files(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_import_rag_files_rest_required_fields( request_type=vertex_rag_data_service.ImportRagFilesRequest, ): @@ -6100,6 +7159,42 @@ def test_get_rag_file_rest(request_type): ) +def test_get_rag_file_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_rag_file in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_rag_file] = mock_rpc + + request = {} + client.get_rag_file(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_rag_file(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_rag_file_rest_required_fields( request_type=vertex_rag_data_service.GetRagFileRequest, ): @@ -6371,6 +7466,42 @@ def test_list_rag_files_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_rag_files_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_rag_files in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_rag_files] = mock_rpc + + request = {} + client.list_rag_files(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_rag_files(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_rag_files_rest_required_fields( request_type=vertex_rag_data_service.ListRagFilesRequest, ): @@ -6713,6 +7844,46 @@ def test_delete_rag_file_rest(request_type): assert response.operation.name == "operations/spam" +def test_delete_rag_file_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_rag_file in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_rag_file] = mock_rpc + + request = {} + client.delete_rag_file(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_rag_file(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_rag_file_rest_required_fields( request_type=vertex_rag_data_service.DeleteRagFileRequest, ): diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_service.py index 67a9b32b53..c9a0695c23 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_service.py @@ -1204,6 +1204,9 @@ def test_retrieve_contexts_empty_call(): with mock.patch.object( type(client.transport.retrieve_contexts), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.retrieve_contexts() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1229,6 +1232,9 @@ def test_retrieve_contexts_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.retrieve_contexts), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.retrieve_contexts(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1237,6 +1243,43 @@ def test_retrieve_contexts_non_empty_request_with_auto_populated_field(): ) +def test_retrieve_contexts_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.retrieve_contexts in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.retrieve_contexts + ] = mock_rpc + request = {} + client.retrieve_contexts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.retrieve_contexts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_retrieve_contexts_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1260,6 +1303,52 @@ async def test_retrieve_contexts_empty_call_async(): assert args[0] == vertex_rag_service.RetrieveContextsRequest() +@pytest.mark.asyncio +async def test_retrieve_contexts_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VertexRagServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.retrieve_contexts + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.retrieve_contexts + ] = mock_object + + request = {} + await client.retrieve_contexts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.retrieve_contexts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_retrieve_contexts_async( transport: str = "grpc_asyncio", @@ -1497,6 +1586,44 @@ def test_retrieve_contexts_rest(request_type): assert isinstance(response, vertex_rag_service.RetrieveContextsResponse) +def test_retrieve_contexts_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.retrieve_contexts in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.retrieve_contexts + ] = mock_rpc + + request = {} + client.retrieve_contexts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.retrieve_contexts(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_retrieve_contexts_rest_required_fields( request_type=vertex_rag_service.RetrieveContextsRequest, ): @@ -2274,8 +2401,34 @@ def test_vertex_rag_service_transport_channel_mtls_with_adc(transport_class): assert transport.grpc_channel == mock_grpc_channel +def test_rag_corpus_path(): + project = "squid" + location = "clam" + rag_corpus = "whelk" + expected = "projects/{project}/locations/{location}/ragCorpora/{rag_corpus}".format( + project=project, + location=location, + rag_corpus=rag_corpus, + ) + actual = VertexRagServiceClient.rag_corpus_path(project, location, rag_corpus) + assert expected == actual + + +def test_parse_rag_corpus_path(): + expected = { + "project": "octopus", + "location": "oyster", + "rag_corpus": "nudibranch", + } + path = VertexRagServiceClient.rag_corpus_path(**expected) + + # Check that the path construction is reversible. 
+ actual = VertexRagServiceClient.parse_rag_corpus_path(path) + assert expected == actual + + def test_common_billing_account_path(): - billing_account = "squid" + billing_account = "cuttlefish" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -2285,7 +2438,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", + "billing_account": "mussel", } path = VertexRagServiceClient.common_billing_account_path(**expected) @@ -2295,7 +2448,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "whelk" + folder = "winkle" expected = "folders/{folder}".format( folder=folder, ) @@ -2305,7 +2458,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "octopus", + "folder": "nautilus", } path = VertexRagServiceClient.common_folder_path(**expected) @@ -2315,7 +2468,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "oyster" + organization = "scallop" expected = "organizations/{organization}".format( organization=organization, ) @@ -2325,7 +2478,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", + "organization": "abalone", } path = VertexRagServiceClient.common_organization_path(**expected) @@ -2335,7 +2488,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "cuttlefish" + project = "squid" expected = "projects/{project}".format( project=project, ) @@ -2345,7 +2498,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "mussel", + "project": "clam", } path = VertexRagServiceClient.common_project_path(**expected) @@ -2355,8 +2508,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "winkle" - location = "nautilus" + project = "whelk" + location 
= "octopus" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -2367,8 +2520,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", + "project": "oyster", + "location": "nudibranch", } path = VertexRagServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py index 57552d7b23..025f07d83e 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py @@ -1184,6 +1184,9 @@ def test_create_study_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_study), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_study() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1207,6 +1210,9 @@ def test_create_study_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_study), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_study(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1215,6 +1221,41 @@ def test_create_study_non_empty_request_with_auto_populated_field(): ) +def test_create_study_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_study in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_study] = mock_rpc + request = {} + client.create_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_study_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1241,6 +1282,52 @@ async def test_create_study_empty_call_async(): assert args[0] == vizier_service.CreateStudyRequest() +@pytest.mark.asyncio +async def test_create_study_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_study + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_study + ] = mock_object + + request = {} + await client.create_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_create_study_async( transport: str = "grpc_asyncio", request_type=vizier_service.CreateStudyRequest @@ -1487,6 +1574,9 @@ def test_get_study_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_study), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_study() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1510,6 +1600,9 @@ def test_get_study_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_study), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_study(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1518,6 +1611,41 @@ def test_get_study_non_empty_request_with_auto_populated_field(): ) +def test_get_study_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_study in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_study] = mock_rpc + request = {} + client.get_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_study_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1544,6 +1672,50 @@ async def test_get_study_empty_call_async(): assert args[0] == vizier_service.GetStudyRequest() +@pytest.mark.asyncio +async def test_get_study_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_study + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_study + ] = mock_object + + request = {} + await client.get_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_study_async( transport: str = "grpc_asyncio", request_type=vizier_service.GetStudyRequest @@ -1774,6 +1946,9 @@ def test_list_studies_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_studies() call.assert_called() _, args, _ = call.mock_calls[0] @@ -1798,6 +1973,9 @@ def test_list_studies_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_studies(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -1807,6 +1985,41 @@ def test_list_studies_non_empty_request_with_auto_populated_field(): ) +def test_list_studies_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_studies in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_studies] = mock_rpc + request = {} + client.list_studies(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_studies(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_studies_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -1830,6 +2043,52 @@ async def test_list_studies_empty_call_async(): assert args[0] == vizier_service.ListStudiesRequest() +@pytest.mark.asyncio +async def test_list_studies_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_studies + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_studies + ] = mock_object + + request = {} + await client.list_studies(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_studies(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_studies_async( transport: str = "grpc_asyncio", request_type=vizier_service.ListStudiesRequest @@ -2245,6 +2504,9 @@ def test_delete_study_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_study), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_study() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2268,6 +2530,9 @@ def test_delete_study_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_study), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_study(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2276,6 +2541,41 @@ def test_delete_study_non_empty_request_with_auto_populated_field(): ) +def test_delete_study_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_study in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_study] = mock_rpc + request = {} + client.delete_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.delete_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_study_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2295,6 +2595,52 @@ async def test_delete_study_empty_call_async(): assert args[0] == vizier_service.DeleteStudyRequest() +@pytest.mark.asyncio +async def test_delete_study_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_study + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_study + ] = mock_object + + request = {} + await client.delete_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.delete_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_study_async( transport: str = "grpc_asyncio", request_type=vizier_service.DeleteStudyRequest @@ -2520,6 +2866,9 @@ def test_lookup_study_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.lookup_study() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2544,6 +2893,9 @@ def test_lookup_study_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.lookup_study(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2553,6 +2905,41 @@ def test_lookup_study_non_empty_request_with_auto_populated_field(): ) +def test_lookup_study_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.lookup_study in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.lookup_study] = mock_rpc + request = {} + client.lookup_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.lookup_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_lookup_study_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2579,6 +2966,52 @@ async def test_lookup_study_empty_call_async(): assert args[0] == vizier_service.LookupStudyRequest() +@pytest.mark.asyncio +async def test_lookup_study_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.lookup_study + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.lookup_study + ] = mock_object + + request = {} + await client.lookup_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.lookup_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_lookup_study_async( transport: str = "grpc_asyncio", request_type=vizier_service.LookupStudyRequest @@ -2806,6 +3239,9 @@ def test_suggest_trials_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.suggest_trials), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.suggest_trials() call.assert_called() _, args, _ = call.mock_calls[0] @@ -2830,6 +3266,9 @@ def test_suggest_trials_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.suggest_trials), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.suggest_trials(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -2839,6 +3278,45 @@ def test_suggest_trials_non_empty_request_with_auto_populated_field(): ) +def test_suggest_trials_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.suggest_trials in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.suggest_trials] = mock_rpc + request = {} + client.suggest_trials(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.suggest_trials(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_suggest_trials_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -2860,6 +3338,56 @@ async def test_suggest_trials_empty_call_async(): assert args[0] == vizier_service.SuggestTrialsRequest() +@pytest.mark.asyncio +async def test_suggest_trials_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.suggest_trials + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.suggest_trials + ] = mock_object + + request = {} + await client.suggest_trials(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.suggest_trials(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_suggest_trials_async( transport: str = "grpc_asyncio", request_type=vizier_service.SuggestTrialsRequest @@ -3013,6 +3541,9 @@ def test_create_trial_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_trial), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.create_trial() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3036,6 +3567,9 @@ def test_create_trial_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.create_trial), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.create_trial(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3044,6 +3578,41 @@ def test_create_trial_non_empty_request_with_auto_populated_field(): ) +def test_create_trial_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_trial in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_trial] = mock_rpc + request = {} + client.create_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_create_trial_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3073,15 +3642,61 @@ async def test_create_trial_empty_call_async(): @pytest.mark.asyncio -async def test_create_trial_async( - transport: str = "grpc_asyncio", request_type=vizier_service.CreateTrialRequest +async def test_create_trial_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", ): - client = VizierServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) - # Everything is optional in proto3 as far as the runtime is concerned, + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_trial + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.create_trial + ] = mock_object + + request = {} + await client.create_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.create_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_create_trial_async( + transport: str = "grpc_asyncio", request_type=vizier_service.CreateTrialRequest +): + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. request = request_type() @@ -3326,6 +3941,9 @@ def test_get_trial_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_trial), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.get_trial() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3349,6 +3967,9 @@ def test_get_trial_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.get_trial), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.get_trial(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3357,6 +3978,41 @@ def test_get_trial_non_empty_request_with_auto_populated_field(): ) +def test_get_trial_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_trial in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_trial] = mock_rpc + request = {} + client.get_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_get_trial_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3385,6 +4041,50 @@ async def test_get_trial_empty_call_async(): assert args[0] == vizier_service.GetTrialRequest() +@pytest.mark.asyncio +async def test_get_trial_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_trial + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.get_trial + ] = mock_object + + request = {} + await client.get_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.get_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_get_trial_async( transport: str = "grpc_asyncio", request_type=vizier_service.GetTrialRequest @@ -3619,6 +4319,9 @@ def test_list_trials_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_trials() call.assert_called() _, args, _ = call.mock_calls[0] @@ -3643,6 +4346,9 @@ def test_list_trials_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.list_trials(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -3652,6 +4358,41 @@ def test_list_trials_non_empty_request_with_auto_populated_field(): ) +def test_list_trials_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_trials in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_trials] = mock_rpc + request = {} + client.list_trials(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_trials(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_trials_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -3675,6 +4416,52 @@ async def test_list_trials_empty_call_async(): assert args[0] == vizier_service.ListTrialsRequest() +@pytest.mark.asyncio +async def test_list_trials_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_trials + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_trials + ] = mock_object + + request = {} + await client.list_trials(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_trials(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_trials_async( transport: str = "grpc_asyncio", request_type=vizier_service.ListTrialsRequest @@ -4107,6 +4894,9 @@ def test_add_trial_measurement_empty_call(): with mock.patch.object( type(client.transport.add_trial_measurement), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.add_trial_measurement() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4132,6 +4922,9 @@ def test_add_trial_measurement_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.add_trial_measurement), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.add_trial_measurement(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4140,6 +4933,46 @@ def test_add_trial_measurement_non_empty_request_with_auto_populated_field(): ) +def test_add_trial_measurement_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.add_trial_measurement + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.add_trial_measurement + ] = mock_rpc + request = {} + client.add_trial_measurement(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.add_trial_measurement(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_add_trial_measurement_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4170,6 +5003,52 @@ async def test_add_trial_measurement_empty_call_async(): assert args[0] == vizier_service.AddTrialMeasurementRequest() +@pytest.mark.asyncio +async def test_add_trial_measurement_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.add_trial_measurement + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.add_trial_measurement + ] = mock_object + + request = {} + await client.add_trial_measurement(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.add_trial_measurement(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_add_trial_measurement_async( transport: str = "grpc_asyncio", @@ -4341,6 +5220,9 @@ def test_complete_trial_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.complete_trial), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.complete_trial() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4365,6 +5247,9 @@ def test_complete_trial_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.complete_trial), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.complete_trial(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4374,6 +5259,41 @@ def test_complete_trial_non_empty_request_with_auto_populated_field(): ) +def test_complete_trial_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.complete_trial in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.complete_trial] = mock_rpc + request = {} + client.complete_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.complete_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_complete_trial_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4402,6 +5322,52 @@ async def test_complete_trial_empty_call_async(): assert args[0] == vizier_service.CompleteTrialRequest() +@pytest.mark.asyncio +async def test_complete_trial_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.complete_trial + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.complete_trial + ] = mock_object + + request = {} + await client.complete_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.complete_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_complete_trial_async( transport: str = "grpc_asyncio", request_type=vizier_service.CompleteTrialRequest @@ -4553,6 +5519,9 @@ def test_delete_trial_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.delete_trial() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4576,6 +5545,9 @@ def test_delete_trial_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.delete_trial(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4584,6 +5556,41 @@ def test_delete_trial_non_empty_request_with_auto_populated_field(): ) +def test_delete_trial_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_trial in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_trial] = mock_rpc + request = {} + client.delete_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.delete_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_delete_trial_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4603,6 +5610,52 @@ async def test_delete_trial_empty_call_async(): assert args[0] == vizier_service.DeleteTrialRequest() +@pytest.mark.asyncio +async def test_delete_trial_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_trial + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_trial + ] = mock_object + + request = {} + await client.delete_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.delete_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_delete_trial_async( transport: str = "grpc_asyncio", request_type=vizier_service.DeleteTrialRequest @@ -4823,6 +5876,9 @@ def test_check_trial_early_stopping_state_empty_call(): with mock.patch.object( type(client.transport.check_trial_early_stopping_state), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.check_trial_early_stopping_state() call.assert_called() _, args, _ = call.mock_calls[0] @@ -4848,6 +5904,9 @@ def test_check_trial_early_stopping_state_non_empty_request_with_auto_populated_ with mock.patch.object( type(client.transport.check_trial_early_stopping_state), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.check_trial_early_stopping_state(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -4856,6 +5915,50 @@ def test_check_trial_early_stopping_state_non_empty_request_with_auto_populated_ ) +def test_check_trial_early_stopping_state_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.check_trial_early_stopping_state + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.check_trial_early_stopping_state + ] = mock_rpc + request = {} + client.check_trial_early_stopping_state(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.check_trial_early_stopping_state(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_check_trial_early_stopping_state_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -4879,6 +5982,56 @@ async def test_check_trial_early_stopping_state_empty_call_async(): assert args[0] == vizier_service.CheckTrialEarlyStoppingStateRequest() +@pytest.mark.asyncio +async def test_check_trial_early_stopping_state_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.check_trial_early_stopping_state + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.check_trial_early_stopping_state + ] = mock_object + + request = {} + await client.check_trial_early_stopping_state(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.check_trial_early_stopping_state(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_check_trial_early_stopping_state_async( transport: str = "grpc_asyncio", @@ -5039,6 +6192,9 @@ def test_stop_trial_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.stop_trial), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.stop_trial() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5062,6 +6218,9 @@ def test_stop_trial_non_empty_request_with_auto_populated_field(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.stop_trial), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) client.stop_trial(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5070,6 +6229,41 @@ def test_stop_trial_non_empty_request_with_auto_populated_field(): ) +def test_stop_trial_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.stop_trial in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.stop_trial] = mock_rpc + request = {} + client.stop_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.stop_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_stop_trial_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5098,6 +6292,50 @@ async def test_stop_trial_empty_call_async(): assert args[0] == vizier_service.StopTrialRequest() +@pytest.mark.asyncio +async def test_stop_trial_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.stop_trial + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.stop_trial + ] = mock_object + + request = {} + await client.stop_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.stop_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_stop_trial_async( transport: str = "grpc_asyncio", request_type=vizier_service.StopTrialRequest @@ -5253,6 +6491,9 @@ def test_list_optimal_trials_empty_call(): with mock.patch.object( type(client.transport.list_optimal_trials), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_optimal_trials() call.assert_called() _, args, _ = call.mock_calls[0] @@ -5278,6 +6519,9 @@ def test_list_optimal_trials_non_empty_request_with_auto_populated_field(): with mock.patch.object( type(client.transport.list_optimal_trials), "__call__" ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) client.list_optimal_trials(request=request) call.assert_called() _, args, _ = call.mock_calls[0] @@ -5286,6 +6530,45 @@ def test_list_optimal_trials_non_empty_request_with_auto_populated_field(): ) +def test_list_optimal_trials_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_optimal_trials in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.list_optimal_trials + ] = mock_rpc + request = {} + client.list_optimal_trials(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_optimal_trials(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + @pytest.mark.asyncio async def test_list_optimal_trials_empty_call_async(): # This test is a coverage failsafe to make sure that totally empty calls, @@ -5309,6 +6592,52 @@ async def test_list_optimal_trials_empty_call_async(): assert args[0] == vizier_service.ListOptimalTrialsRequest() +@pytest.mark.asyncio +async def test_list_optimal_trials_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VizierServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_optimal_trials + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + class AwaitableMock(mock.AsyncMock): + def __await__(self): + self.await_count += 1 + return iter([]) + + mock_object = AwaitableMock() + client._client._transport._wrapped_methods[ + client._client._transport.list_optimal_trials + ] = mock_object + + request = {} + await client.list_optimal_trials(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_object.call_count == 1 + + await client.list_optimal_trials(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + @pytest.mark.asyncio async def test_list_optimal_trials_async( transport: str = "grpc_asyncio", @@ -5704,6 +7033,42 @@ def get_message_fields(field): assert response.inactive_reason == "inactive_reason_value" +def test_create_study_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_study in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_study] = mock_rpc + + request = {} + client.create_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_study_rest_required_fields( request_type=vizier_service.CreateStudyRequest, ): @@ -5986,6 +7351,42 @@ def test_get_study_rest(request_type): assert response.inactive_reason == "inactive_reason_value" +def test_get_study_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_study in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_study] = mock_rpc + + request = {} + client.get_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_study_rest_required_fields(request_type=vizier_service.GetStudyRequest): transport_class = transports.VizierServiceRestTransport @@ -6247,6 +7648,42 @@ def test_list_studies_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_studies_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_studies in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_studies] = mock_rpc + + request = {} + client.list_studies(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_studies(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_studies_rest_required_fields( request_type=vizier_service.ListStudiesRequest, ): @@ -6581,6 +8018,42 @@ def test_delete_study_rest(request_type): assert response is None +def test_delete_study_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_study in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_study] = mock_rpc + + request = {} + client.delete_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.delete_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_study_rest_required_fields( request_type=vizier_service.DeleteStudyRequest, ): @@ -6841,6 +8314,42 @@ def test_lookup_study_rest(request_type): assert response.inactive_reason == "inactive_reason_value" +def test_lookup_study_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.lookup_study in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.lookup_study] = mock_rpc + + request = {} + client.lookup_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.lookup_study(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_lookup_study_rest_required_fields( request_type=vizier_service.LookupStudyRequest, ): @@ -7114,6 +8623,46 @@ def test_suggest_trials_rest(request_type): assert response.operation.name == "operations/spam" +def test_suggest_trials_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.suggest_trials in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.suggest_trials] = mock_rpc + + request = {} + client.suggest_trials(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.suggest_trials(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_suggest_trials_rest_required_fields( request_type=vizier_service.SuggestTrialsRequest, ): @@ -7448,6 +8997,42 @@ def get_message_fields(field): assert response.custom_job == "custom_job_value" +def test_create_trial_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_trial in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_trial] = mock_rpc + + request = {} + client.create_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.create_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_create_trial_rest_required_fields( request_type=vizier_service.CreateTrialRequest, ): @@ -7738,6 +9323,42 @@ def test_get_trial_rest(request_type): assert response.custom_job == "custom_job_value" +def test_get_trial_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_trial in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_trial] = mock_rpc + + request = {} + client.get_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.get_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_get_trial_rest_required_fields(request_type=vizier_service.GetTrialRequest): transport_class = transports.VizierServiceRestTransport @@ -8003,6 +9624,42 @@ def test_list_trials_rest(request_type): assert response.next_page_token == "next_page_token_value" +def test_list_trials_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_trials in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_trials] = mock_rpc + + request = {} + client.list_trials(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_trials(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_trials_rest_required_fields( request_type=vizier_service.ListTrialsRequest, ): @@ -8356,6 +10013,47 @@ def test_add_trial_measurement_rest(request_type): assert response.custom_job == "custom_job_value" +def test_add_trial_measurement_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.add_trial_measurement + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.add_trial_measurement + ] = mock_rpc + + request = {} + client.add_trial_measurement(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.add_trial_measurement(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_add_trial_measurement_rest_required_fields( request_type=vizier_service.AddTrialMeasurementRequest, ): @@ -8587,6 +10285,42 @@ def test_complete_trial_rest(request_type): assert response.custom_job == "custom_job_value" +def test_complete_trial_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.complete_trial in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.complete_trial] = mock_rpc + + request = {} + client.complete_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.complete_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_complete_trial_rest_required_fields( request_type=vizier_service.CompleteTrialRequest, ): @@ -8795,6 +10529,42 @@ def test_delete_trial_rest(request_type): assert response is None +def test_delete_trial_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_trial in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_trial] = mock_rpc + + request = {} + client.delete_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.delete_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_delete_trial_rest_required_fields( request_type=vizier_service.DeleteTrialRequest, ): @@ -9050,6 +10820,51 @@ def test_check_trial_early_stopping_state_rest(request_type): assert response.operation.name == "operations/spam" +def test_check_trial_early_stopping_state_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.check_trial_early_stopping_state + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.check_trial_early_stopping_state + ] = mock_rpc + + request = {} + client.check_trial_early_stopping_state(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.check_trial_early_stopping_state(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_check_trial_early_stopping_state_rest_required_fields( request_type=vizier_service.CheckTrialEarlyStoppingStateRequest, ): @@ -9277,6 +11092,42 @@ def test_stop_trial_rest(request_type): assert response.custom_job == "custom_job_value" +def test_stop_trial_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.stop_trial in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.stop_trial] = mock_rpc + + request = {} + client.stop_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.stop_trial(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_stop_trial_rest_required_fields(request_type=vizier_service.StopTrialRequest): transport_class = transports.VizierServiceRestTransport @@ -9483,6 +11334,46 @@ def test_list_optimal_trials_rest(request_type): assert isinstance(response, vizier_service.ListOptimalTrialsResponse) +def test_list_optimal_trials_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VizierServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_optimal_trials in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_optimal_trials + ] = mock_rpc + + request = {} + client.list_optimal_trials(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.list_optimal_trials(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + def test_list_optimal_trials_rest_required_fields( request_type=vizier_service.ListOptimalTrialsRequest, ): From 3c3727b48ce4ba12bdaf36806cda4907a788d38e Mon Sep 17 00:00:00 2001 From: Amy Wu Date: Fri, 3 May 2024 12:15:21 -0700 Subject: [PATCH 04/30] fix: Add DeprecationWarning to vertexai.preview predictive models SDK PiperOrigin-RevId: 630461931 --- vertexai/preview/_workflow/driver/remote.py | 6 +++++- vertexai/preview/_workflow/shared/constants.py | 18 ++++++++++++++++++ .../preview/_workflow/shared/model_utils.py | 10 +++++----- vertexai/preview/developer/__init__.py | 8 +++++++- .../preview/hyperparameter_tuning/__init__.py | 5 ++++- vertexai/preview/initializer.py | 8 ++++++-- vertexai/preview/tabular_models/__init__.py | 4 ++++ 7 files changed, 49 insertions(+), 10 deletions(-) diff --git a/vertexai/preview/_workflow/driver/remote.py b/vertexai/preview/_workflow/driver/remote.py index e8364e9221..b26a12b575 100644 --- a/vertexai/preview/_workflow/driver/remote.py +++ b/vertexai/preview/_workflow/driver/remote.py @@ -18,7 +18,7 @@ import abc import inspect from typing import Any, Callable, Dict, Optional, Type - +import warnings from vertexai.preview._workflow import driver from vertexai.preview._workflow.executor import ( training, @@ -27,6 +27,7 @@ any_serializer, ) from vertexai.preview._workflow.shared import ( + constants, supported_frameworks, ) from vertexai.preview.developer import remote_specs @@ -41,6 +42,9 @@ def remote_method_decorator( return driver.VertexRemoteFunctor(method, remote_executor, remote_executor_kwargs) +warnings.warn(constants._V2_0_WARNING_MSG, DeprecationWarning, stacklevel=1) + + def remote_class_decorator(cls: Type) -> Type: """Add Vertex attributes to a class object.""" diff --git 
a/vertexai/preview/_workflow/shared/constants.py b/vertexai/preview/_workflow/shared/constants.py index 8cf36a2cca..4c7e5cd644 100644 --- a/vertexai/preview/_workflow/shared/constants.py +++ b/vertexai/preview/_workflow/shared/constants.py @@ -19,3 +19,21 @@ _START_EXECUTION_MSG = "Start remote execution on Vertex..." _END_EXECUTION_MSG = "Remote execution is completed." + +_V2_0_WARNING_MSG = """ +After May 30, 2024, importing any code below will result in an error. +Please verify that you are explicitly pinning to a version of `google-cloud-aiplatform` +(e.g., google-cloud-aiplatform==[1.32.0, 1.49.0]) if you need to continue using this +library. + +from vertexai.preview import ( + init, + remote, + VertexModel, + register, + from_pretrained, + developer, + hyperparameter_tuning, + tabular_models, +) +""" diff --git a/vertexai/preview/_workflow/shared/model_utils.py b/vertexai/preview/_workflow/shared/model_utils.py index 663a4740b3..f64a3fa25d 100644 --- a/vertexai/preview/_workflow/shared/model_utils.py +++ b/vertexai/preview/_workflow/shared/model_utils.py @@ -24,6 +24,7 @@ import os import re from typing import Any, Dict, Optional, Union +import warnings from google.cloud import aiplatform from google.cloud.aiplatform import base @@ -35,14 +36,12 @@ any_serializer, serializers_base, ) +from vertexai.preview._workflow.shared import constants # These need to be imported to be included in _ModelGardenModel.__init_subclass__ from vertexai.language_models import ( _language_models, ) # pylint:disable=unused-import -from vertexai.vision_models import ( - _vision_models, -) # pylint:disable=unused-import from vertexai._model_garden import _model_garden_models from google.cloud.aiplatform import _publisher_models from vertexai.preview._workflow.executor import training @@ -60,9 +59,10 @@ _OUTPUT_ESTIMATOR_DIR = "output_estimator" _OUTPUT_PREDICTIONS_DIR = "output_predictions" - _LOGGER = base.Logger("vertexai.remote_execution") 
+warnings.warn(constants._V2_0_WARNING_MSG, DeprecationWarning, stacklevel=1) + def _get_model_file_from_image_uri(container_image_uri: str) -> str: """Gets the model file from the container image URI. @@ -121,7 +121,7 @@ def _generate_remote_job_output_path(base_gcs_dir: str) -> str: def _get_model_from_successful_custom_job( job_dir: str, -) -> Union["sklearn.base.BaseEstimator", "tf.Module", "torch.nn.Module"]: +) -> Union["sklearn.base.BaseEstimator", "tf.Module", "torch.nn.Module"]: # noqa: F821 serializer = any_serializer.AnySerializer() diff --git a/vertexai/preview/developer/__init__.py b/vertexai/preview/developer/__init__.py index 7df80244b5..0039dc8e04 100644 --- a/vertexai/preview/developer/__init__.py +++ b/vertexai/preview/developer/__init__.py @@ -15,17 +15,23 @@ # limitations under the License. # +import warnings from vertexai.preview._workflow.serialization_engine import ( any_serializer, ) from vertexai.preview._workflow.serialization_engine import ( serializers_base, ) -from vertexai.preview._workflow.shared import configs +from vertexai.preview._workflow.shared import ( + configs, + constants, +) from vertexai.preview.developer import mark from vertexai.preview.developer import remote_specs +warnings.warn(constants._V2_0_WARNING_MSG, DeprecationWarning, stacklevel=1) + PersistentResourceConfig = configs.PersistentResourceConfig Serializer = serializers_base.Serializer SerializationMetadata = serializers_base.SerializationMetadata diff --git a/vertexai/preview/hyperparameter_tuning/__init__.py b/vertexai/preview/hyperparameter_tuning/__init__.py index 16a54be359..461acffc4b 100644 --- a/vertexai/preview/hyperparameter_tuning/__init__.py +++ b/vertexai/preview/hyperparameter_tuning/__init__.py @@ -15,11 +15,14 @@ # limitations under the License. 
# - +import warnings from vertexai.preview.hyperparameter_tuning import ( vizier_hyperparameter_tuner, ) +from vertexai.preview._workflow.shared import constants + +warnings.warn(constants._V2_0_WARNING_MSG, DeprecationWarning, stacklevel=1) VizierHyperparameterTuner = vizier_hyperparameter_tuner.VizierHyperparameterTuner diff --git a/vertexai/preview/initializer.py b/vertexai/preview/initializer.py index 2d4babdab9..08e7b1dc54 100644 --- a/vertexai/preview/initializer.py +++ b/vertexai/preview/initializer.py @@ -14,13 +14,16 @@ # from typing import Optional - +import warnings from google.cloud import aiplatform from google.cloud.aiplatform import base from vertexai.preview._workflow.executor import ( persistent_resource_util, ) -from vertexai.preview._workflow.shared import configs +from vertexai.preview._workflow.shared import ( + configs, + constants, +) _LOGGER = base.Logger(__name__) @@ -30,6 +33,7 @@ class _Config: """Store common configurations and current workflow for remote execution.""" def __init__(self): + warnings.warn(constants._V2_0_WARNING_MSG, DeprecationWarning, stacklevel=1) self._remote = False self._cluster = None diff --git a/vertexai/preview/tabular_models/__init__.py b/vertexai/preview/tabular_models/__init__.py index d96f82480a..5ee4607d17 100644 --- a/vertexai/preview/tabular_models/__init__.py +++ b/vertexai/preview/tabular_models/__init__.py @@ -15,10 +15,14 @@ # limitations under the License. 
# +import warnings +from vertexai.preview._workflow.shared import constants from vertexai.preview.tabular_models import tabnet_trainer +warnings.warn(constants._V2_0_WARNING_MSG, DeprecationWarning, stacklevel=1) + TabNetTrainer = tabnet_trainer.TabNetTrainer From 7279dab402cc10fedf3c5ecab5c6a0609e4b7e5b Mon Sep 17 00:00:00 2001 From: Yvonne Yu Date: Fri, 3 May 2024 12:17:33 -0700 Subject: [PATCH 05/30] feat: GenAI - Added `response_style` to `GenerationConfig` PiperOrigin-RevId: 630462525 --- .../cloud/aiplatform_v1beta1/types/content.py | 27 +++++++++++++++++++ .../generative_models/_generative_models.py | 5 ++++ 2 files changed, 32 insertions(+) diff --git a/google/cloud/aiplatform_v1beta1/types/content.py b/google/cloud/aiplatform_v1beta1/types/content.py index fbd449c5c6..8173e3baef 100644 --- a/google/cloud/aiplatform_v1beta1/types/content.py +++ b/google/cloud/aiplatform_v1beta1/types/content.py @@ -309,8 +309,30 @@ class GenerationConfig(proto.Message): The model needs to be prompted to output the appropriate response type, otherwise the behavior is undefined. This is a preview feature. + + response_style (google.cloud.aiplatform_v1beta1.types.GenerationConfig.ResponseStyle): + Control Three levels of creativity in the model output. + Default: RESPONSE_STYLE_BALANCED """ + class ResponseStyle(proto.Enum): + r"""Choices of the response style. + + Values: + RESPONSE_STYLE_UNSPECIFIED (0): + response style unspecified. + RESPONSE_STYLE_PRECISE (1): + Precise response. + RESPONSE_STYLE_BALANCED (2): + Default response style. + RESPONSE_STYLE_CREATIVE (3): + Creative response style. 
+ """ + RESPONSE_STYLE_UNSPECIFIED = 0 + RESPONSE_STYLE_PRECISE = 1 + RESPONSE_STYLE_BALANCED = 2 + RESPONSE_STYLE_CREATIVE = 3 + temperature: float = proto.Field( proto.FLOAT, number=1, @@ -354,6 +376,11 @@ class GenerationConfig(proto.Message): proto.STRING, number=13, ) + response_style: ResponseStyle = proto.Field( + proto.ENUM, + number=14, + enum=ResponseStyle, + ) class SafetySetting(proto.Message): diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py index cd30e7a776..09222a34f1 100644 --- a/vertexai/generative_models/_generative_models.py +++ b/vertexai/generative_models/_generative_models.py @@ -1181,6 +1181,7 @@ class ResponseValidationError(ResponseBlockedError): class GenerationConfig: """Parameters for the generation.""" + ResponseStyle = gapic_content_types.GenerationConfig.ResponseStyle def __init__( self, @@ -1194,6 +1195,7 @@ def __init__( presence_penalty: Optional[float] = None, frequency_penalty: Optional[float] = None, response_mime_type: Optional[str] = None, + response_style: Optional["GenerationConfig.ResponseStyle"] = None, ): r"""Constructs a GenerationConfig object. @@ -1216,6 +1218,7 @@ def __init__( The model needs to be prompted to output the appropriate response type, otherwise the behavior is undefined. + response_style: Control three levels of creativity in the model output. 
Usage: ``` @@ -1228,6 +1231,7 @@ def __init__( candidate_count=1, max_output_tokens=100, stop_sequences=["\n\n\n"], + response_style=ResponseStyle.RESPONSE_STYLE_PRECISE, ) ) ``` @@ -1242,6 +1246,7 @@ def __init__( presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, response_mime_type=response_mime_type, + response_style=response_style, ) @classmethod From 510c8334961cdb6f801863ecbd8fe49bf69b6c68 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Fri, 3 May 2024 20:29:55 -0700 Subject: [PATCH 06/30] fix: Upload the reference model in model registry PiperOrigin-RevId: 630570083 --- google/cloud/aiplatform/models.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index 734a9807ef..0b935fb166 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -3272,10 +3272,11 @@ def upload( if local_model: container_spec = local_model.get_serving_container_spec() appended_user_agent = [prediction_constants.CUSTOM_PREDICTION_ROUTINES] + elif not serving_container_image_uri and not artifact_uri: + # It's a referenced/place holder model. + container_spec = None else: - # Referenced models do not have container_image and artifact_uri - # Skip the container_image if this is a referenced model - if not serving_container_image_uri and artifact_uri: + if not serving_container_image_uri: raise ValueError( "The parameter `serving_container_image_uri` is required " "if no `local_model` is provided." @@ -5256,10 +5257,10 @@ def evaluate( the class "cat" corresponds with 0.97 in the example above. prediction_label_column (str): Optional. The column name of the field containing classes the model is scoring. Formatted to be able to find nested - columns, delimeted by `.`. If not set, defaulted to `prediction.classes` for classification. + columns, delimited by `.`. If not set, defaulted to `prediction.classes` for classification. 
prediction_score_column (str): Optional. The column name of the field containing batch prediction scores. Formatted to be able to find nested columns, - delimeted by `.`. If not set, defaulted to `prediction.scores` for a `classification` problem_type, `prediction.value` + delimited by `.`. If not set, defaulted to `prediction.scores` for a `classification` problem_type, `prediction.value` for a `regression` problem_type. staging_bucket (str): Optional. The GCS directory to use for staging files from this evaluation job. Defaults to the value set in From 262b184d1606a2099b705ff57324b1477f54e492 Mon Sep 17 00:00:00 2001 From: Lingyin Wu Date: Mon, 6 May 2024 09:38:51 -0700 Subject: [PATCH 07/30] chore: Add more cleanup for match engine index test. PiperOrigin-RevId: 631091852 --- .../aiplatform/test_matching_engine_index.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/system/aiplatform/test_matching_engine_index.py b/tests/system/aiplatform/test_matching_engine_index.py index 655a214284..113b48c06f 100644 --- a/tests/system/aiplatform/test_matching_engine_index.py +++ b/tests/system/aiplatform/test_matching_engine_index.py @@ -236,6 +236,15 @@ def test_create_get_list_matching_engine_index(self, shared_state): location=e2e_base._LOCATION, ) + # Clean up resources from previous test runs. + for index_endpoint in aiplatform.MatchingEngineIndexEndpoint.list(): + for deployed_index in index_endpoint.deployed_indexes: + index_endpoint.undeploy_index(deployed_index_id=deployed_index.id) + index_endpoint.delete() + + for index in aiplatform.MatchingEngineIndex.list(): + index.delete() + # Create an index index = aiplatform.MatchingEngineIndex.create_tree_ah_index( display_name=_TEST_INDEX_DISPLAY_NAME, @@ -446,6 +455,15 @@ def test_matching_engine_stream_index(self, shared_state): location=e2e_base._LOCATION, ) + # Clean up resources from previous test runs. 
+ for index_endpoint in aiplatform.MatchingEngineIndexEndpoint.list(): + for deployed_index in index_endpoint.deployed_indexes: + index_endpoint.undeploy_index(deployed_index_id=deployed_index.id) + index_endpoint.delete() + + for index in aiplatform.MatchingEngineIndex.list(): + index.delete() + # Create an index stream_index = aiplatform.MatchingEngineIndex.create_tree_ah_index( display_name=_TEST_STREAM_INDEX_DISPLAY_NAME, From b0c5eda79489d4b32972b2acea647e3c8cdc3ce9 Mon Sep 17 00:00:00 2001 From: Jason Dai Date: Mon, 6 May 2024 10:07:17 -0700 Subject: [PATCH 08/30] feat: AutoSxS Pairwise Metric in Rapid Evaluation SDK PiperOrigin-RevId: 631101180 --- tests/unit/vertexai/test_evaluation.py | 258 +++++++++++++- vertexai/preview/evaluation/__init__.py | 4 +- vertexai/preview/evaluation/_base.py | 6 +- vertexai/preview/evaluation/_eval_tasks.py | 8 +- vertexai/preview/evaluation/_evaluation.py | 274 +++++++++++---- .../preview/evaluation/metrics/__init__.py | 2 + vertexai/preview/evaluation/metrics/_base.py | 52 ++- .../metrics/_instance_evaluation.py | 322 ++++++++++-------- 8 files changed, 707 insertions(+), 219 deletions(-) diff --git a/tests/unit/vertexai/test_evaluation.py b/tests/unit/vertexai/test_evaluation.py index c330506792..fefe234621 100644 --- a/tests/unit/vertexai/test_evaluation.py +++ b/tests/unit/vertexai/test_evaluation.py @@ -26,6 +26,7 @@ from google.cloud.aiplatform_v1beta1.types import ( evaluation_service as gapic_evaluation_service_types, ) +from vertexai import generative_models from vertexai.preview import evaluation from vertexai.preview.evaluation import utils import pandas as pd @@ -34,7 +35,7 @@ _TEST_PROJECT = "test-project" _TEST_LOCATION = "us-central1" -_TEST_METRICS = [ +_TEST_METRICS = ( "exact_match", "bleu", "rouge_1", @@ -53,7 +54,7 @@ "question_answering_relevance", "question_answering_helpfulness", "question_answering_correctness", -] +) _TEST_EVAL_DATASET = pd.DataFrame( { "response": ["test", "text"], @@ -78,7 +79,7 @@ 
""" -_MOCK_EXACT_MATCH_RESULT = [ +_MOCK_EXACT_MATCH_RESULT = ( gapic_evaluation_service_types.EvaluateInstancesResponse( exact_match_results=gapic_evaluation_service_types.ExactMatchResults( exact_match_metric_values=[ @@ -93,9 +94,9 @@ ] ) ), -] +) -_MOCK_FLUENCY_RESULT = [ +_MOCK_FLUENCY_RESULT = ( gapic_evaluation_service_types.EvaluateInstancesResponse( fluency_result=gapic_evaluation_service_types.FluencyResult( score=5, explanation="explanation", confidence=1.0 @@ -106,7 +107,34 @@ score=4, explanation="explanation", confidence=0.5 ) ), -] +) + +_MOCK_PAIRWISE_SUMMARIZATION_QUALITY_RESULT = ( + gapic_evaluation_service_types.EvaluateInstancesResponse( + pairwise_summarization_quality_result=gapic_evaluation_service_types.PairwiseSummarizationQualityResult( + pairwise_choice=gapic_evaluation_service_types.PairwiseChoice.BASELINE, + explanation="explanation", + confidence=1.0, + ) + ), + gapic_evaluation_service_types.EvaluateInstancesResponse( + pairwise_summarization_quality_result=gapic_evaluation_service_types.PairwiseSummarizationQualityResult( + pairwise_choice=gapic_evaluation_service_types.PairwiseChoice.CANDIDATE, + explanation="explanation", + confidence=0.5, + ) + ), +) + +_MOCK_MODEL_INFERENCE_RESPONSE = generative_models.GenerationResponse.from_dict( + { + "candidates": [ + { + "content": {"parts": [{"text": "test_response"}]}, + } + ] + } +) @pytest.fixture @@ -260,6 +288,188 @@ def test_compute_pointwise_metrics(self, api_transport): 0.5, ] + @pytest.mark.parametrize("api_transport", ["grpc", "rest"]) + def test_compute_pairwise_metrics_with_model_inference(self, api_transport): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + api_transport=api_transport, + ) + eval_dataset = pd.DataFrame( + { + "context": ["test", "context"], + "instruction": ["test", "instruction"], + } + ) + mock_baseline_model = mock.create_autospec( + generative_models.GenerativeModel, instance=True + ) + 
mock_baseline_model.generate_content.return_value = ( + _MOCK_MODEL_INFERENCE_RESPONSE + ) + mock_baseline_model._model_name = "publishers/google/model/gemini-pro" + mock_candidate_model = mock.create_autospec( + generative_models.GenerativeModel, instance=True + ) + mock_candidate_model.generate_content.return_value = ( + _MOCK_MODEL_INFERENCE_RESPONSE + ) + mock_candidate_model._model_name = "publishers/google/model/gemini-pro" + test_metrics = [ + evaluation.PairwiseMetric( + metric="summarization_quality", + baseline_model=mock_baseline_model, + use_reference=False, + ) + ] + test_eval_task = evaluation.EvalTask(dataset=eval_dataset, metrics=test_metrics) + mock_metric_results = _MOCK_PAIRWISE_SUMMARIZATION_QUALITY_RESULT + with mock.patch.object( + target=gapic_evaluation_services.EvaluationServiceAsyncClient, + attribute="evaluate_instances", + side_effect=mock_metric_results, + ): + test_result = test_eval_task.evaluate( + model=mock_candidate_model, + prompt_template="{instruction} test prompt template {context}", + ) + + assert test_result.summary_metrics["row_count"] == 2 + assert set(test_result.metrics_table.columns.values) == set( + [ + "context", + "instruction", + "completed_prompt", + "response", + "baseline_model_response", + "pairwise_summarization_quality/pairwise_choice", + "pairwise_summarization_quality/explanation", + "pairwise_summarization_quality/confidence", + ] + ) + assert list( + test_result.metrics_table[ + "pairwise_summarization_quality/pairwise_choice" + ].values + ) == ["BASELINE", "CANDIDATE"] + assert list( + test_result.metrics_table[ + "pairwise_summarization_quality/explanation" + ].values + ) == [ + "explanation", + "explanation", + ] + assert list( + test_result.metrics_table[ + "pairwise_summarization_quality/confidence" + ].values + ) == [ + 1.0, + 0.5, + ] + assert set(test_result.summary_metrics.keys()) == set( + [ + "row_count", + "pairwise_summarization_quality/candidate_model_win_rate", + 
"pairwise_summarization_quality/baseline_model_win_rate", + ] + ) + assert ( + test_result.summary_metrics[ + "pairwise_summarization_quality/candidate_model_win_rate" + ] + == 0.5 + ) + assert ( + test_result.summary_metrics[ + "pairwise_summarization_quality/baseline_model_win_rate" + ] + == 0.5 + ) + + @pytest.mark.parametrize("api_transport", ["grpc", "rest"]) + def test_compute_pairwise_metrics_without_inference(self, api_transport): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + api_transport=api_transport, + ) + eval_dataset = pd.DataFrame( + { + "response": ["test", "text"], + "baseline_model_response": ["baseline", "response"], + "reference": ["test", "reference"], + } + ) + test_metrics = [ + evaluation.PairwiseMetric( + metric="summarization_quality", + baseline_model=None, + use_reference=True, + ) + ] + test_eval_task = evaluation.EvalTask(dataset=eval_dataset, metrics=test_metrics) + mock_metric_results = _MOCK_PAIRWISE_SUMMARIZATION_QUALITY_RESULT + with mock.patch.object( + target=gapic_evaluation_services.EvaluationServiceAsyncClient, + attribute="evaluate_instances", + side_effect=mock_metric_results, + ): + test_result = test_eval_task.evaluate() + + assert test_result.summary_metrics["row_count"] == 2 + assert set(test_result.metrics_table.columns.values) == set( + [ + "response", + "baseline_model_response", + "reference", + "pairwise_summarization_quality/pairwise_choice", + "pairwise_summarization_quality/explanation", + "pairwise_summarization_quality/confidence", + ] + ) + assert list( + test_result.metrics_table[ + "pairwise_summarization_quality/pairwise_choice" + ].values + ) == ["BASELINE", "CANDIDATE"] + assert list( + test_result.metrics_table[ + "pairwise_summarization_quality/explanation" + ].values + ) == [ + "explanation", + "explanation", + ] + assert list( + test_result.metrics_table[ + "pairwise_summarization_quality/confidence" + ].values + ) == [ + 1.0, + 0.5, + ] + assert 
set(test_result.summary_metrics.keys()) == set( + [ + "row_count", + "pairwise_summarization_quality/candidate_model_win_rate", + "pairwise_summarization_quality/baseline_model_win_rate", + ] + ) + assert ( + test_result.summary_metrics[ + "pairwise_summarization_quality/candidate_model_win_rate" + ] + == 0.5 + ) + assert ( + test_result.summary_metrics[ + "pairwise_summarization_quality/baseline_model_win_rate" + ] + == 0.5 + ) + @pytest.mark.usefixtures("google_auth_mock") class TestEvaluationErrors: @@ -325,6 +535,42 @@ def test_evaluate_invalid_prompt_template_placeholder(self): prompt_template="test_prompt_template {invalid_placeholder}", ) + def test_evaluate_pairwise_metrics_with_multiple_baseline_models(self): + eval_dataset = pd.DataFrame( + { + "context": ["test", "context"], + "instruction": ["test", "instruction"], + } + ) + mock_baseline_model_1 = mock.create_autospec( + generative_models.GenerativeModel, instance=True + ) + mock_baseline_model_1._model_name = "publishers/google/model/gemini-1.0-pro" + mock_baseline_model_2 = mock.create_autospec( + generative_models.GenerativeModel, instance=True + ) + mock_baseline_model_2._model_name = "publishers/google/model/gemini-1.5-pro" + mock_candidate_model = mock.create_autospec( + generative_models.GenerativeModel, instance=True + ) + mock_candidate_model._model_name = "publishers/google/model/gemini-1.0-ultra" + test_metrics = [ + evaluation.PairwiseMetric( + metric="summarization_quality", + baseline_model=mock_baseline_model_1, + ), + evaluation.PairwiseMetric( + metric="summarization_quality", + baseline_model=mock_baseline_model_2, + ), + ] + test_eval_task = evaluation.EvalTask(dataset=eval_dataset, metrics=test_metrics) + with pytest.raises( + ValueError, + match="Not all PairwiseMetric instances have the same baseline_model", + ): + test_eval_task.evaluate(model=mock_candidate_model) + @pytest.mark.usefixtures("google_auth_mock") class TestEvaluationUtils: diff --git 
a/vertexai/preview/evaluation/__init__.py b/vertexai/preview/evaluation/__init__.py index 67895b4377..9b2f9b2218 100644 --- a/vertexai/preview/evaluation/__init__.py +++ b/vertexai/preview/evaluation/__init__.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -"""Rapid GenAI Evaluation Module.""" +"""GenAI Rapid Evaluation Module.""" from vertexai.preview.evaluation import _base from vertexai.preview.evaluation import _eval_tasks @@ -25,11 +25,13 @@ EvalResult = _base.EvalResult EvalTask = _eval_tasks.EvalTask CustomMetric = metrics.CustomMetric +PairwiseMetric = metrics.PairwiseMetric make_metric = metrics.make_metric PromptTemplate = prompt_template.PromptTemplate __all__ = [ "CustomMetric", + "PairwiseMetric", "EvalResult", "EvalTask", "make_metric", diff --git a/vertexai/preview/evaluation/_base.py b/vertexai/preview/evaluation/_base.py index 588e1e6eac..0fe52f6e30 100644 --- a/vertexai/preview/evaluation/_base.py +++ b/vertexai/preview/evaluation/_base.py @@ -37,14 +37,14 @@ class EvaluationRunConfig: Attributes: dataset: The dataset to evaluate. - metrics: The list of metric names to evaluate, or a metrics bundle for an - evaluation task, or custom metric instances. + metrics: The list of metric names, or metric bundle names, or + CustomMetric instances, or PairwiseMetric instances to evaluate. column_map: The dictionary of column name overrides in the dataset. client: The asynchronous evaluation client. 
""" dataset: "pd.DataFrame" - metrics: List[Union[str, metrics_base.CustomMetric]] + metrics: List[Union[str, metrics_base.CustomMetric, metrics_base.PairwiseMetric]] column_map: Dict[str, str] client: gapic_evaluation_services.EvaluationServiceAsyncClient diff --git a/vertexai/preview/evaluation/_eval_tasks.py b/vertexai/preview/evaluation/_eval_tasks.py index cf86d12710..651ec127fc 100644 --- a/vertexai/preview/evaluation/_eval_tasks.py +++ b/vertexai/preview/evaluation/_eval_tasks.py @@ -76,7 +76,7 @@ class EvalTask: Metrics Details: The supported metrics, metric bundle descriptions, grading rubrics, and the required input fields can be found on the Vertex AI public - documentation. + documentation page [Evaluation methods and metrics](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval). Usage: 1. To perform bring your own prediction evaluation, provide the model @@ -205,6 +205,7 @@ def __init__( "tool_call_quality", ], metrics_base.CustomMetric, + metrics_base.PairwiseMetric, ] ], experiment: Optional[str] = None, @@ -225,8 +226,9 @@ def __init__( (e.g., 'gs://bucket/data.csv'). * BigQuery table URI: Loaded from Google Cloud BigQuery (e.g., 'bq://project-id.dataset.table_name'). - metrics: The list of metrics names to be evaluated, or a metrics - bundle for an evaluation task, or custom metric instances. + metrics: The list of metric names, or metric bundle names, or + CustomMetric instances, or PairwiseMetric instances to evaluate. + Prompt template is required for PairwiseMetric. experiment: The name of the experiment to log the evaluations to. content_column_name: The column name of content in the dataset to send to the model. If not set, default to `content`. 
diff --git a/vertexai/preview/evaluation/_evaluation.py b/vertexai/preview/evaluation/_evaluation.py index 1c24664060..476f9663e1 100644 --- a/vertexai/preview/evaluation/_evaluation.py +++ b/vertexai/preview/evaluation/_evaluation.py @@ -88,19 +88,19 @@ def _replace_metric_bundle_with_metrics( - metrics_list: List[Union[str, metrics_base.CustomMetric]], -) -> List[str]: + metrics: List[Union[str, metrics_base.CustomMetric, metrics_base.PairwiseMetric]], +) -> List[Union[str, metrics_base.CustomMetric, metrics_base.PairwiseMetric]]: """Replaces metric bundles with corresponding metrics. Args: - metrics_list: The original list containing metrics bundle names. + metrics: The original metrics list containing metric bundle names. Returns: - The modified metrics list containing only metric names. + The modified metrics list with metric bundle names replaced. """ modified_list = [] - for item in metrics_list: + for item in metrics: if item in _METRICS_BUNDLE_TO_METRIC_NAMES.keys(): modified_list.extend(_METRICS_BUNDLE_TO_METRIC_NAMES[item]) else: @@ -144,8 +144,10 @@ def _compute_custom_metrics( def _separate_custom_metrics( - metrics: List[str], -) -> Tuple[List[str], List[metrics_base.CustomMetric],]: + metrics: List[Union[str, metrics_base.CustomMetric, metrics_base.PairwiseMetric]], +) -> Tuple[ + List[Union[str, metrics_base.PairwiseMetric]], List[metrics_base.CustomMetric] +]: """Separates the metrics list into API and custom metrics.""" custom_metrics = [] api_metrics = [] @@ -174,17 +176,32 @@ def _compute_summary_metrics( summary_metrics[constants.MetricResult.ROW_COUNT_KEY] = metrics_table.shape[0] for metric in evaluation_run_config.metrics: try: - # TODO(b/325078638): implement additional aggregate methods. 
- summary_metrics[f"{str(metric)}/mean"] = metrics_table.loc[ - :, str(metric) - ].mean() - summary_metrics[f"{str(metric)}/std"] = metrics_table.loc[ - :, str(metric) - ].std() - except (ValueError, KeyError): + if isinstance(metric, metrics_base.PairwiseMetric): + summary_metrics[ + f"{metric.pairwise_metric_name}/candidate_model_win_rate" + ] = ( + metrics_table[f"{metric.pairwise_metric_name}/pairwise_choice"] + == "CANDIDATE" + ).mean() + summary_metrics[ + f"{metric.pairwise_metric_name}/baseline_model_win_rate" + ] = ( + metrics_table[f"{metric.pairwise_metric_name}/pairwise_choice"] + == "BASELINE" + ).mean() + else: + # TODO(b/325078638): implement additional aggregate methods. + summary_metrics[f"{str(metric)}/mean"] = metrics_table.loc[ + :, str(metric) + ].mean() + summary_metrics[f"{str(metric)}/std"] = metrics_table.loc[ + :, str(metric) + ].std() + except (ValueError, KeyError) as e: _LOGGER.warning( f"Failed to compute metric statistics for {metric}. This metric" - " output contains error from the Autorater." + " output contains error from the Autorater.\n" + f"{type(e).__name__}: {e}" ) continue return summary_metrics @@ -193,7 +210,7 @@ def _compute_summary_metrics( def _generate_response_from_gemini( model: generative_models.GenerativeModel, prompt: str ) -> str: - """Generates response from Gemini model. + """Generates a text response from Gemini model from a text prompt. Args: model: The Gemini model instance. 
@@ -211,7 +228,7 @@ def _generate_response_from_gemini( if not response.candidates: raise RuntimeError( f"The model response was blocked due to {response._raw_response.prompt_feedback.block_reason.name}.\n" - f"Blocke reason message: {response._raw_response.prompt_feedback.block_reason_message}.\n" + f"Blocked reason message: {response._raw_response.prompt_feedback.block_reason_message}.\n" "The input prompt may be blocked for safety reasons.", f"Prompt: {prompt}.", ) @@ -228,7 +245,7 @@ def _generate_response_from_gemini( return response.candidates[0].content.parts[0].text except Exception: raise RuntimeError( - "Failed to generate response candidates from Gemini model.\n" + f"Failed to generate response candidates from Gemini model {model._model_name}.\n" f"Response: {response}.\n" f"Prompt: {prompt}." ) @@ -237,19 +254,30 @@ def _generate_response_from_gemini( def _generate_response_from_gemini_model( model: generative_models.GenerativeModel, evaluation_run_config: evaluation_base.EvaluationRunConfig, + is_baseline_model: bool = False, ) -> None: """Generates responses from Gemini model. Args: model: The Gemini model instance. evaluation_run_config: Evaluation Run Configurations. + is_baseline_model: Whether the model is a baseline model for PairwiseMetric. """ + response_column_name = ( + constants.Dataset.BASELINE_MODEL_RESPONSE_COLUMN + if is_baseline_model + else constants.Dataset.MODEL_RESPONSE_COLUMN + ) + _LOGGER.info( + f"Generating a total of {evaluation_run_config.dataset.shape[0]} " + f"responses from Gemini model {model._model_name.split('/')[-1]}." 
+ ) if ( constants.Dataset.COMPLETED_PROMPT_COLUMN in evaluation_run_config.dataset.columns ): evaluation_run_config.dataset[ - constants.Dataset.MODEL_RESPONSE_COLUMN + response_column_name ] = evaluation_run_config.dataset[ constants.Dataset.COMPLETED_PROMPT_COLUMN ].apply( @@ -257,31 +285,47 @@ def _generate_response_from_gemini_model( ) else: evaluation_run_config.dataset[ - constants.Dataset.MODEL_RESPONSE_COLUMN + response_column_name ] = evaluation_run_config.dataset[ evaluation_run_config.column_map[constants.Dataset.CONTENT_COLUMN] ].apply( lambda x: _generate_response_from_gemini(model, x) ) + _LOGGER.info( + f"All {evaluation_run_config.dataset.shape[0]} responses are successfully" + f" generated from Gemini model {model._model_name.split('/')[-1]}." + ) def _generate_response_from_custom_model_fn( model_fn: Callable[[str], str], evaluation_run_config: evaluation_base.EvaluationRunConfig, + is_baseline_model: bool = False, ) -> None: """Generates responses from a custom model function. Args: model_fn: The custom model function. evaluation_run_config: Evaluation Run Configurations. + is_baseline_model: Whether the model is a baseline model for + PairwiseMetric. """ + response_column_name = ( + constants.Dataset.BASELINE_MODEL_RESPONSE_COLUMN + if is_baseline_model + else constants.Dataset.MODEL_RESPONSE_COLUMN + ) + _LOGGER.info( + f"Generating a total of {evaluation_run_config.dataset.shape[0]} " + "responses from the custom model function." 
+ ) try: if ( constants.Dataset.COMPLETED_PROMPT_COLUMN in evaluation_run_config.dataset.columns ): evaluation_run_config.dataset[ - constants.Dataset.MODEL_RESPONSE_COLUMN + response_column_name ] = evaluation_run_config.dataset[ constants.Dataset.COMPLETED_PROMPT_COLUMN ].apply( @@ -289,7 +333,7 @@ def _generate_response_from_custom_model_fn( ) else: evaluation_run_config.dataset[ - constants.Dataset.MODEL_RESPONSE_COLUMN + response_column_name ] = evaluation_run_config.dataset[ evaluation_run_config.column_map[constants.Dataset.CONTENT_COLUMN] ].apply( @@ -298,6 +342,11 @@ def _generate_response_from_custom_model_fn( except (ValueError, IndexError) as e: _LOGGER.warning(f"Failed to generate response from model function: {e}") + _LOGGER.info( + f"All {evaluation_run_config.dataset.shape[0]} responses are successfully" + " generated from the custom model function." + ) + def _check_placeholder_columns_exist( dataset: "pd.DataFrame", placeholder_names_set: Set[str] @@ -322,14 +371,18 @@ def _check_placeholder_columns_exist( def _complete_prompt_for_dataset( - evaluation_run_config: evaluation_base.EvaluationRunConfig, prompt_template: str + evaluation_run_config: evaluation_base.EvaluationRunConfig, + prompt_template: Union[prompt_template_base.PromptTemplate, str], ) -> None: """Adds a column in dataset for completed prompts from placeholder columns. Args: evaluation_run_config: Evaluation Run Configurations. - prompt_template: A prompt template string with placeholders that can be - formatted with dataset columns. + prompt_template: A `PromptTemplate` object or a prompt template string + with placeholders that can be assembled from the evaluation dataset. The + placeholders can be represented in curly braces `{placeholder}`, and + must be included in the dataset columns if specified. The placeholder + names cannot contain spaces. Returns: The completed prompt template string to send to the model. 
@@ -338,7 +391,16 @@ def _complete_prompt_for_dataset( ValueError: If any placeholder names do not exist in the dataset columns or the prompt template is invalid. """ - prompt_template = prompt_template_base.PromptTemplate(prompt_template) + if not prompt_template: + raise ValueError("Prompt template cannot be an empty string.") + + _LOGGER.info( + 'Completing prompts from the prompt_template. The "completed_prompt" ' + "column in the EvalResult.metrics_table has the completed prompts " + "used for model content generation." + ) + if isinstance(prompt_template, str): + prompt_template = prompt_template_base.PromptTemplate(prompt_template) _check_placeholder_columns_exist( evaluation_run_config.dataset, prompt_template.placeholders ) @@ -381,9 +443,6 @@ def _parse_metric_results_to_dataframe( metrics_table = pd.DataFrame(dict(zip(instance_df.columns, instance_df.values.T))) for metric_name, metric_results in results.items(): - scores = [ - result.get(constants.MetricResult.SCORE_KEY) for result in metric_results - ] if ( metric_name in constants.Metric.MODEL_BASED_METRIC_LIST @@ -403,8 +462,24 @@ def _parse_metric_results_to_dataframe( metrics_table[ f"{metric_name}/{constants.MetricResult.CONFIDENCE_KEY}" ] = confidences - - metrics_table[metric_name] = scores + if metric_name in constants.Metric.PAIRWISE_METRIC_LIST: + pairwise_choices = [ + result.get(constants.MetricResult.PAIRWISE_CHOICE_KEY) + for result in metric_results + ] + metrics_table[ + f"{metric_name}/{constants.MetricResult.PAIRWISE_CHOICE_KEY}" + ] = pairwise_choices + if ( + metric_name + in constants.Metric.AUTOMATIC_METRIC_LIST + + constants.Metric.MODEL_BASED_METRIC_LIST + ): + scores = [ + result.get(constants.MetricResult.SCORE_KEY) + for result in metric_results + ] + metrics_table[metric_name] = scores return metrics_table @@ -441,19 +516,28 @@ async def _compute_metrics( instance_list.append(row_dict) - for metric_name in api_metrics: + for metric in api_metrics: task = asyncio.create_task( 
_instance_evaluation.evaluate_instances_async( client=evaluation_run_config.client, request=_instance_evaluation.build_request( - metric_name=metric_name, + metric=metric, row_dict=row_dict, evaluation_run_config=evaluation_run_config, ), ) ) + if isinstance(metric, metrics_base.PairwiseMetric): + metric_name = metric.pairwise_metric_name + else: + metric_name = metric tasks_by_metric[metric_name].append(task) + api_request_count = len(tasks_by_metric) * len(next(iter(tasks_by_metric.values()))) + _LOGGER.info( + f"Computing metrics with a total of {api_request_count} Vertex online" + " evaluation service requests." + ) results_dict = { metric_name: await asyncio.gather(*tasks) for metric_name, tasks in tasks_by_metric.items() @@ -462,18 +546,48 @@ async def _compute_metrics( instance_df = pd.DataFrame.from_dict(instance_list) metrics_table = _parse_metric_results_to_dataframe(instance_df, results_dict) + _LOGGER.info(f"All {api_request_count} metrics are successfully computed.") summary_metrics = _compute_summary_metrics(evaluation_run_config, metrics_table) return summary_metrics, metrics_table +def _run_model_inference( + model: Union[generative_models.GenerativeModel, Callable[[str], str]], + evaluation_run_config: evaluation_base.EvaluationRunConfig, + is_baseline_model: bool = False, +) -> None: + """Runs model inference on dataset for evaluation. + + Args: + model: The GenerativeModel instance or a custom model function to generate + responses to evaluate. If not provided, the evaluation is computed with + the `response` column in the `dataset`. + evaluation_run_config: Evaluation Run Configurations. + is_baseline_model: Whether the model is a baseline model for PairwiseMetric. + + Raises: + ValueError: If the model or baseline model is not supported. 
+ """ + if isinstance(model, generative_models.GenerativeModel): + _generate_response_from_gemini_model( + model, evaluation_run_config, is_baseline_model + ) + elif callable(model): + _generate_response_from_custom_model_fn( + model, evaluation_run_config, is_baseline_model + ) + else: + raise ValueError(f"Unsupported model or baseline model type: {type(model)}") + + def evaluate( dataset: "pd.DataFrame", - metrics: List[Union[str, metrics_base.CustomMetric]], + metrics: List[Union[str, metrics_base.CustomMetric, metrics_base.PairwiseMetric]], *, model: Optional[ Union[generative_models.GenerativeModel, Callable[[str], str]] ] = None, - prompt_template: Optional[str] = None, + prompt_template: Optional[Union[str, prompt_template_base.PromptTemplate]] = None, content_column_name: str = "content", reference_column_name: str = "reference", response_column_name: str = "response", @@ -484,16 +598,17 @@ def evaluate( Args: dataset: The dataset to evaluate. - metrics: The list of metrics names to evaluate, or a metrics bundle for an - evaluation task, or custom metric instances. + metrics: The list of metric names, or metric bundle names, or CustomMetric + instances, or PairwiseMetric instances to evaluate. Prompt template is + required for PairwiseMetric. model: The GenerativeModel instance or a custom model function to generate responses to evaluate. If not provided, the evaluation is computed with the `response` column in the `dataset`. - prompt_template: A prompt template string compatible with `PromptTemplate` - class with placeholders that can be formatted with dataset columns to - create completed prompts. The placeholders can be represented in curly - braces `{placeholder}`, and must be included in the dataset columns if - specified. The placeholder names cannot contain spaces. 
+ prompt_template: A `PromptTemplate` or a prompt template string compatible + with `PromptTemplate` class with placeholders that can be formatted with + dataset columns to create completed prompts. The placeholders can be + represented in curly braces `{placeholder}`, and must be included in the + dataset columns if specified. The placeholder names cannot contain spaces. content_column_name: The column name of content in the dataset to send to the model. If not set, default to `content`. reference_column_name: The column name of ground truth in the dataset. If @@ -508,6 +623,11 @@ class with placeholders that can be formatted with dataset columns to Returns: EvalResult with summary metrics and a metrics table for per-instance metrics. + + Raises: + ValueError: If the metrics list is empty, or the prompt template is not + provided for PairwiseMetric, or multiple baseline models are specified for + PairwiseMetric instances. """ if not metrics: @@ -526,33 +646,61 @@ class with placeholders that can be formatted with dataset columns to client=utils.create_evaluation_service_async_client(), ) + if set(evaluation_run_config.metrics).intersection( + set(constants.Metric.AUTOMATIC_METRIC_LIST) + ): + evaluation_run_config.validate_dataset_column( + constants.Dataset.REFERENCE_COLUMN + ) + + baseline_model = None + pairwise_metric_exists = any( + isinstance(metric, metrics_base.PairwiseMetric) + for metric in evaluation_run_config.metrics + ) + if pairwise_metric_exists: + pairwise_metric_instances = [ + metric + for metric in evaluation_run_config.metrics + if isinstance(metric, metrics_base.PairwiseMetric) + ] + if ( + len(set(instance.baseline_model for instance in pairwise_metric_instances)) + > 1 + ): + # TODO(b/330598854): support multiple baseline models to compare + # with the candidate model. 
+ raise ValueError( + "Not all PairwiseMetric instances have the same baseline_model" + ) + baseline_model = pairwise_metric_instances[0].baseline_model + if prompt_template: _complete_prompt_for_dataset(evaluation_run_config, prompt_template) + evaluation_run_config.validate_dataset_column( + constants.Dataset.COMPLETED_PROMPT_COLUMN + ) + elif baseline_model: + raise ValueError("Prompt template is required for computing PairwiseMetric.") + elif model: + evaluation_run_config.validate_dataset_column(constants.Dataset.CONTENT_COLUMN) if model: - if prompt_template: - evaluation_run_config.validate_dataset_column( - constants.Dataset.COMPLETED_PROMPT_COLUMN - ) - else: - evaluation_run_config.validate_dataset_column( - constants.Dataset.CONTENT_COLUMN - ) + _run_model_inference(model, evaluation_run_config) + evaluation_run_config.validate_dataset_column( + constants.Dataset.MODEL_RESPONSE_COLUMN + ) - if isinstance(model, generative_models.GenerativeModel): - _generate_response_from_gemini_model(model, evaluation_run_config) - elif callable(model): - _generate_response_from_custom_model_fn(model, evaluation_run_config) - else: + if baseline_model: + _run_model_inference( + model=baseline_model, + evaluation_run_config=evaluation_run_config, + is_baseline_model=True, + ) + if pairwise_metric_exists: evaluation_run_config.validate_dataset_column( - constants.Dataset.MODEL_RESPONSE_COLUMN + constants.Dataset.BASELINE_MODEL_RESPONSE_COLUMN ) - if set(evaluation_run_config.metrics).intersection( - set(constants.Metric.AUTOMATIC_METRIC_LIST) - ): - evaluation_run_config.validate_dataset_column( - constants.Dataset.REFERENCE_COLUMN - ) if asyncio.get_event_loop().is_running(): asyncio.set_event_loop(asyncio.new_event_loop()) diff --git a/vertexai/preview/evaluation/metrics/__init__.py b/vertexai/preview/evaluation/metrics/__init__.py index 94d768a030..ef6c78fe30 100644 --- a/vertexai/preview/evaluation/metrics/__init__.py +++ 
b/vertexai/preview/evaluation/metrics/__init__.py @@ -21,9 +21,11 @@ ) CustomMetric = _base.CustomMetric +PairwiseMetric = _base.PairwiseMetric make_metric = _base.make_metric __all__ = [ "CustomMetric", + "PairwiseMetric", "make_metric", ] diff --git a/vertexai/preview/evaluation/metrics/_base.py b/vertexai/preview/evaluation/metrics/_base.py index 35ab69aec5..f2cb85b784 100644 --- a/vertexai/preview/evaluation/metrics/_base.py +++ b/vertexai/preview/evaluation/metrics/_base.py @@ -15,7 +15,57 @@ # limitations under the License. # -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Literal, Optional, Union +from vertexai import generative_models + + +class PairwiseMetric: + """The Side-by-side(SxS) Pairwise Metric.""" + + def __init__( + self, + *, + metric: Literal["summarization_quality", "question_answering_quality"], + baseline_model: Union[ + generative_models.GenerativeModel, Callable[[str], str] + ] = None, + use_reference: bool = False, + version: Optional[int] = None, + ): + """Initializes the Side-by-side(SxS) Pairwise evaluation metric. + + Args: + metric: The Side-by-side(SxS) pairwise evaluation metric name. + baseline_model: The baseline model for the Side-by-side(SxS) comparison. + use_reference: Whether to use reference to compute the metric. If specified, + the reference column is required in the dataset. + version: The metric version to use for evaluation. 
+ """ + self._metric = metric + self._baseline_model = baseline_model + self._use_reference = use_reference + self._version = version + + def __str__(self): + return self.pairwise_metric_name + + @property + def pairwise_metric_name(self) -> str: + return f"pairwise_{self._metric}" + + @property + def baseline_model( + self, + ) -> Union[generative_models.GenerativeModel, Callable[[str], str]]: + return self._baseline_model + + @property + def use_reference(self) -> bool: + return self._use_reference + + @property + def version(self) -> int: + return self._version class CustomMetric: diff --git a/vertexai/preview/evaluation/metrics/_instance_evaluation.py b/vertexai/preview/evaluation/metrics/_instance_evaluation.py index dc7a8ddf2b..d393684ed2 100644 --- a/vertexai/preview/evaluation/metrics/_instance_evaluation.py +++ b/vertexai/preview/evaluation/metrics/_instance_evaluation.py @@ -16,7 +16,7 @@ # """Library for Metrics Computation with Evaluation Service Async Client.""" -from typing import Any, Dict +from typing import Any, Dict, Union from google import api_core from google.cloud.aiplatform import base @@ -25,88 +25,78 @@ evaluation_service as gapic_evaluation_services, ) from google.cloud.aiplatform_v1beta1.types import ( - evaluation_service as gapic_evaluation_service_types, -) -from vertexai.preview.evaluation import ( - _base as eval_base, + evaluation_service as gapic_eval_service_types, ) +from vertexai.preview.evaluation import _base as eval_base from vertexai.preview.evaluation import constants +from vertexai.preview.evaluation.metrics import ( + _base as metrics_base, +) from google.protobuf import json_format + _LOGGER = base.Logger(__name__) _METRIC_NAME_TO_METRIC_SPEC = { # Automatic Metrics. 
- constants.Metric.EXACT_MATCH: (gapic_evaluation_service_types.ExactMatchSpec()), - constants.Metric.BLEU: gapic_evaluation_service_types.BleuSpec(), - constants.Metric.ROUGE_1: gapic_evaluation_service_types.RougeSpec( - rouge_type="rouge1" - ), - constants.Metric.ROUGE_2: gapic_evaluation_service_types.RougeSpec( - rouge_type="rouge2" - ), - constants.Metric.ROUGE_L: gapic_evaluation_service_types.RougeSpec( - rouge_type="rougeL" - ), - constants.Metric.ROUGE_L_SUM: gapic_evaluation_service_types.RougeSpec( + constants.Metric.EXACT_MATCH: (gapic_eval_service_types.ExactMatchSpec()), + constants.Metric.BLEU: gapic_eval_service_types.BleuSpec(), + constants.Metric.ROUGE_1: gapic_eval_service_types.RougeSpec(rouge_type="rouge1"), + constants.Metric.ROUGE_2: gapic_eval_service_types.RougeSpec(rouge_type="rouge2"), + constants.Metric.ROUGE_L: gapic_eval_service_types.RougeSpec(rouge_type="rougeL"), + constants.Metric.ROUGE_L_SUM: gapic_eval_service_types.RougeSpec( rouge_type="rougeLsum" ), - constants.Metric.TOOL_CALL_VALID: ( - gapic_evaluation_service_types.ToolCallValidSpec() - ), - constants.Metric.TOOL_NAME_MATCH: ( - gapic_evaluation_service_types.ToolNameMatchSpec() - ), + constants.Metric.TOOL_CALL_VALID: (gapic_eval_service_types.ToolCallValidSpec()), + constants.Metric.TOOL_NAME_MATCH: (gapic_eval_service_types.ToolNameMatchSpec()), constants.Metric.TOOL_PARAMETER_KV_MATCH: ( - gapic_evaluation_service_types.ToolParameterKVMatchSpec() + gapic_eval_service_types.ToolParameterKVMatchSpec() ), constants.Metric.TOOL_PARAMETER_KEY_MATCH: ( - gapic_evaluation_service_types.ToolParameterKeyMatchSpec() + gapic_eval_service_types.ToolParameterKeyMatchSpec() ), # Model-based Pointwise Metrics. 
- constants.Metric.FLUENCY: gapic_evaluation_service_types.FluencySpec(), - constants.Metric.COHERENCE: gapic_evaluation_service_types.CoherenceSpec(), - constants.Metric.SAFETY: gapic_evaluation_service_types.SafetySpec(), - constants.Metric.GROUNDEDNESS: (gapic_evaluation_service_types.GroundednessSpec()), - constants.Metric.FULFILLMENT: (gapic_evaluation_service_types.FulfillmentSpec()), + constants.Metric.FLUENCY: gapic_eval_service_types.FluencySpec(), + constants.Metric.COHERENCE: gapic_eval_service_types.CoherenceSpec(), + constants.Metric.SAFETY: gapic_eval_service_types.SafetySpec(), + constants.Metric.GROUNDEDNESS: (gapic_eval_service_types.GroundednessSpec()), + constants.Metric.FULFILLMENT: (gapic_eval_service_types.FulfillmentSpec()), constants.Metric.SUMMARIZATION_QUALITY: ( - gapic_evaluation_service_types.SummarizationQualitySpec() + gapic_eval_service_types.SummarizationQualitySpec() ), constants.Metric.SUMMARIZATION_HELPFULNESS: ( - gapic_evaluation_service_types.SummarizationHelpfulnessSpec() + gapic_eval_service_types.SummarizationHelpfulnessSpec() ), constants.Metric.SUMMARIZATION_VERBOSITY: ( - gapic_evaluation_service_types.SummarizationVerbositySpec() + gapic_eval_service_types.SummarizationVerbositySpec() ), constants.Metric.QUESTION_ANSWERING_QUALITY: ( - gapic_evaluation_service_types.QuestionAnsweringQualitySpec() + gapic_eval_service_types.QuestionAnsweringQualitySpec() ), constants.Metric.QUESTION_ANSWERING_RELEVANCE: ( - gapic_evaluation_service_types.QuestionAnsweringRelevanceSpec() + gapic_eval_service_types.QuestionAnsweringRelevanceSpec() ), constants.Metric.QUESTION_ANSWERING_CORRECTNESS: ( - gapic_evaluation_service_types.QuestionAnsweringCorrectnessSpec( - use_reference=True - ) + gapic_eval_service_types.QuestionAnsweringCorrectnessSpec(use_reference=True) ), constants.Metric.QUESTION_ANSWERING_HELPFULNESS: ( - gapic_evaluation_service_types.QuestionAnsweringHelpfulnessSpec() + 
gapic_eval_service_types.QuestionAnsweringHelpfulnessSpec() ), # Side-by-side(SxS) Pairwise Metrics. constants.Metric.PAIRWISE_SUMMARIZATION_QUALITY: ( - gapic_evaluation_service_types.PairwiseSummarizationQualitySpec() + gapic_eval_service_types.PairwiseSummarizationQualitySpec() ), constants.Metric.PAIRWISE_QUESTION_ANSWERING_QUALITY: ( - gapic_evaluation_service_types.PairwiseQuestionAnsweringQualitySpec() + gapic_eval_service_types.PairwiseQuestionAnsweringQualitySpec() ), } def build_request( - metric_name: str, + metric: Union[str, metrics_base.PairwiseMetric], row_dict: Dict[str, Any], evaluation_run_config: eval_base.EvaluationRunConfig, -) -> gapic_evaluation_service_types.EvaluateInstancesRequest: +) -> gapic_eval_service_types.EvaluateInstancesRequest: """Builds a metric instance and form the request for the evaluation service. Args: @@ -130,9 +120,18 @@ def build_request( ) ) + if isinstance(metric, metrics_base.PairwiseMetric): + metric_name = metric.pairwise_metric_name + else: + metric_name = metric if metric_name not in _METRIC_NAME_TO_METRIC_SPEC: raise ValueError(f"Metric name: {metric_name} not supported.") + metric_spec = _METRIC_NAME_TO_METRIC_SPEC[metric_name] + if isinstance(metric, metrics_base.PairwiseMetric): + metric_spec.use_reference = metric.use_reference + metric_spec.version = metric.version + column_map = evaluation_run_config.column_map prediction = row_dict.get( column_map.get(constants.Dataset.MODEL_RESPONSE_COLUMN), "" @@ -144,32 +143,39 @@ def build_request( context = row_dict.get(column_map.get(constants.Dataset.CONTEXT_COLUMN), "") instruction = row_dict.get(column_map.get(constants.Dataset.INSTRUCTION_COLUMN), "") + if "use_reference" in json_format.MessageToDict( + metric_spec._pb, preserving_proto_field_name=True + ): + evaluation_run_config.validate_dataset_column( + constants.Dataset.REFERENCE_COLUMN + ) + # Automatic Metrics. 
if metric_name == constants.Metric.EXACT_MATCH: - instance = gapic_evaluation_service_types.ExactMatchInput( + instance = gapic_eval_service_types.ExactMatchInput( metric_spec=metric_spec, instances=[ - gapic_evaluation_service_types.ExactMatchInstance( + gapic_eval_service_types.ExactMatchInstance( prediction=prediction, reference=reference, ) ], ) - return gapic_evaluation_service_types.EvaluateInstancesRequest( + return gapic_eval_service_types.EvaluateInstancesRequest( location=location_path, exact_match_input=instance, ) if metric_name == constants.Metric.BLEU: - instance = gapic_evaluation_service_types.BleuInput( + instance = gapic_eval_service_types.BleuInput( metric_spec=metric_spec, instances=[ - gapic_evaluation_service_types.BleuInstance( + gapic_eval_service_types.BleuInstance( prediction=prediction, reference=reference, ) ], ) - return gapic_evaluation_service_types.EvaluateInstancesRequest( + return gapic_eval_service_types.EvaluateInstancesRequest( location=location_path, bleu_input=instance, ) @@ -179,128 +185,122 @@ def build_request( constants.Metric.ROUGE_L, constants.Metric.ROUGE_L_SUM, ): - instance = gapic_evaluation_service_types.RougeInput( + instance = gapic_eval_service_types.RougeInput( metric_spec=metric_spec, instances=[ - gapic_evaluation_service_types.RougeInstance( + gapic_eval_service_types.RougeInstance( prediction=prediction, reference=reference, ) ], ) - return gapic_evaluation_service_types.EvaluateInstancesRequest( + return gapic_eval_service_types.EvaluateInstancesRequest( location=location_path, rouge_input=instance, ) if metric_name == constants.Metric.TOOL_CALL_VALID: - instance = gapic_evaluation_service_types.ToolCallValidInput( + instance = gapic_eval_service_types.ToolCallValidInput( metric_spec=metric_spec, instances=[ - gapic_evaluation_service_types.ToolCallValidInstance( + gapic_eval_service_types.ToolCallValidInstance( prediction=prediction, reference=reference, ) ], ) - return 
gapic_evaluation_service_types.EvaluateInstancesRequest( + return gapic_eval_service_types.EvaluateInstancesRequest( location=location_path, tool_call_valid_input=instance, ) if metric_name == constants.Metric.TOOL_NAME_MATCH: - instance = gapic_evaluation_service_types.ToolNameMatchInput( + instance = gapic_eval_service_types.ToolNameMatchInput( metric_spec=metric_spec, instances=[ - gapic_evaluation_service_types.ToolNameMatchInstance( + gapic_eval_service_types.ToolNameMatchInstance( prediction=prediction, reference=reference, ) ], ) - return gapic_evaluation_service_types.EvaluateInstancesRequest( + return gapic_eval_service_types.EvaluateInstancesRequest( location=location_path, tool_name_match_input=instance, ) if metric_name == constants.Metric.TOOL_PARAMETER_KEY_MATCH: - instance = gapic_evaluation_service_types.ToolParameterKeyMatchInput( + instance = gapic_eval_service_types.ToolParameterKeyMatchInput( metric_spec=metric_spec, instances=[ - gapic_evaluation_service_types.ToolParameterKeyMatchInstance( + gapic_eval_service_types.ToolParameterKeyMatchInstance( prediction=prediction, reference=reference, ) ], ) - return gapic_evaluation_service_types.EvaluateInstancesRequest( + return gapic_eval_service_types.EvaluateInstancesRequest( location=location_path, tool_parameter_key_match_input=instance, ) if metric_name == constants.Metric.TOOL_PARAMETER_KV_MATCH: - instance = gapic_evaluation_service_types.ToolParameterKVMatchInput( + instance = gapic_eval_service_types.ToolParameterKVMatchInput( metric_spec=metric_spec, instances=[ - gapic_evaluation_service_types.ToolParameterKVMatchInstance( + gapic_eval_service_types.ToolParameterKVMatchInstance( prediction=prediction, reference=reference, ) ], ) - return gapic_evaluation_service_types.EvaluateInstancesRequest( + return gapic_eval_service_types.EvaluateInstancesRequest( location=location_path, tool_parameter_kv_match_input=instance, ) # Model-based Pointwise Metrics. 
if metric_name == constants.Metric.COHERENCE: - coherence_input = gapic_evaluation_service_types.CoherenceInput( + coherence_input = gapic_eval_service_types.CoherenceInput( metric_spec=metric_spec, - instance=gapic_evaluation_service_types.CoherenceInstance( - prediction=prediction - ), + instance=gapic_eval_service_types.CoherenceInstance(prediction=prediction), ) - return gapic_evaluation_service_types.EvaluateInstancesRequest( + return gapic_eval_service_types.EvaluateInstancesRequest( location=location_path, coherence_input=coherence_input, ) if metric_name == constants.Metric.FLUENCY: - fluency_input = gapic_evaluation_service_types.FluencyInput( + fluency_input = gapic_eval_service_types.FluencyInput( metric_spec=metric_spec, - instance=gapic_evaluation_service_types.FluencyInstance( - prediction=prediction - ), + instance=gapic_eval_service_types.FluencyInstance(prediction=prediction), ) - return gapic_evaluation_service_types.EvaluateInstancesRequest( + return gapic_eval_service_types.EvaluateInstancesRequest( location=location_path, fluency_input=fluency_input, ) if metric_name == constants.Metric.SAFETY: - safety_input = gapic_evaluation_service_types.SafetyInput( + safety_input = gapic_eval_service_types.SafetyInput( metric_spec=metric_spec, - instance=gapic_evaluation_service_types.SafetyInstance( - prediction=prediction - ), + instance=gapic_eval_service_types.SafetyInstance(prediction=prediction), ) - return gapic_evaluation_service_types.EvaluateInstancesRequest( + return gapic_eval_service_types.EvaluateInstancesRequest( location=location_path, safety_input=safety_input, ) if metric_name == constants.Metric.GROUNDEDNESS: - groundedness_input = gapic_evaluation_service_types.GroundednessInput( + groundedness_input = gapic_eval_service_types.GroundednessInput( metric_spec=metric_spec, - instance=gapic_evaluation_service_types.GroundednessInstance( + instance=gapic_eval_service_types.GroundednessInstance( prediction=prediction, context=context ), ) - 
return gapic_evaluation_service_types.EvaluateInstancesRequest( + return gapic_eval_service_types.EvaluateInstancesRequest( location=location_path, groundedness_input=groundedness_input, ) if metric_name == constants.Metric.FULFILLMENT: - fulfillment_input = gapic_evaluation_service_types.FulfillmentInput( + fulfillment_input = gapic_eval_service_types.FulfillmentInput( metric_spec=metric_spec, - instance=gapic_evaluation_service_types.FulfillmentInstance( + instance=gapic_eval_service_types.FulfillmentInstance( prediction=prediction, instruction=instruction ), ) - return gapic_evaluation_service_types.EvaluateInstancesRequest( + return gapic_eval_service_types.EvaluateInstancesRequest( location=location_path, fulfillment_input=fulfillment_input, ) @@ -309,95 +309,105 @@ def build_request( if metric_name == constants.Metric.SUMMARIZATION_QUALITY: # TODO(b/330807319): allow set reference field after setting metric spec is allowed. summarization_quality_input = ( - gapic_evaluation_service_types.SummarizationQualityInput( + gapic_eval_service_types.SummarizationQualityInput( metric_spec=metric_spec, - instance=gapic_evaluation_service_types.SummarizationQualityInstance( + instance=gapic_eval_service_types.SummarizationQualityInstance( prediction=prediction, context=context, instruction=instruction ), ) ) - return gapic_evaluation_service_types.EvaluateInstancesRequest( + return gapic_eval_service_types.EvaluateInstancesRequest( location=location_path, summarization_quality_input=summarization_quality_input, ) if metric_name == constants.Metric.SUMMARIZATION_HELPFULNESS: # TODO(b/330807319): allow set reference field after setting metric spec is allowed. 
- summarization_helpfulness_input = gapic_evaluation_service_types.SummarizationHelpfulnessInput( - metric_spec=metric_spec, - instance=gapic_evaluation_service_types.SummarizationHelpfulnessInstance( - prediction=prediction, context=context, instruction=instruction - ), + summarization_helpfulness_input = ( + gapic_eval_service_types.SummarizationHelpfulnessInput( + metric_spec=metric_spec, + instance=gapic_eval_service_types.SummarizationHelpfulnessInstance( + prediction=prediction, context=context, instruction=instruction + ), + ) ) - return gapic_evaluation_service_types.EvaluateInstancesRequest( + return gapic_eval_service_types.EvaluateInstancesRequest( location=location_path, summarization_helpfulness_input=summarization_helpfulness_input, ) if metric_name == constants.Metric.SUMMARIZATION_VERBOSITY: # TODO(b/330807319): allow set reference field after setting metric spec is allowed. summarization_verbosity_input = ( - gapic_evaluation_service_types.SummarizationVerbosityInput( + gapic_eval_service_types.SummarizationVerbosityInput( metric_spec=metric_spec, - instance=gapic_evaluation_service_types.SummarizationVerbosityInstance( + instance=gapic_eval_service_types.SummarizationVerbosityInstance( prediction=prediction, context=context, instruction=instruction ), ) ) - return gapic_evaluation_service_types.EvaluateInstancesRequest( + return gapic_eval_service_types.EvaluateInstancesRequest( location=location_path, summarization_verbosity_input=summarization_verbosity_input, ) if metric_name == constants.Metric.QUESTION_ANSWERING_QUALITY: # TODO(b/330807319): allow set reference field after setting metric spec is allowed. 
- question_answering_quality_input = gapic_evaluation_service_types.QuestionAnsweringQualityInput( - metric_spec=metric_spec, - instance=gapic_evaluation_service_types.QuestionAnsweringQualityInstance( - prediction=prediction, context=context, instruction=instruction - ), + question_answering_quality_input = ( + gapic_eval_service_types.QuestionAnsweringQualityInput( + metric_spec=metric_spec, + instance=gapic_eval_service_types.QuestionAnsweringQualityInstance( + prediction=prediction, context=context, instruction=instruction + ), + ) ) - return gapic_evaluation_service_types.EvaluateInstancesRequest( + return gapic_eval_service_types.EvaluateInstancesRequest( location=location_path, question_answering_quality_input=question_answering_quality_input, ) if metric_name == constants.Metric.QUESTION_ANSWERING_HELPFULNESS: # TODO(b/330807319): allow set reference field after setting metric spec is allowed. - question_answering_helpfulness_input = gapic_evaluation_service_types.QuestionAnsweringHelpfulnessInput( - metric_spec=metric_spec, - instance=gapic_evaluation_service_types.QuestionAnsweringHelpfulnessInstance( - prediction=prediction, - context=context, - instruction=instruction, - ), + question_answering_helpfulness_input = ( + gapic_eval_service_types.QuestionAnsweringHelpfulnessInput( + metric_spec=metric_spec, + instance=gapic_eval_service_types.QuestionAnsweringHelpfulnessInstance( + prediction=prediction, + context=context, + instruction=instruction, + ), + ) ) - return gapic_evaluation_service_types.EvaluateInstancesRequest( + return gapic_eval_service_types.EvaluateInstancesRequest( location=location_path, question_answering_helpfulness_input=question_answering_helpfulness_input, ) if metric_name == constants.Metric.QUESTION_ANSWERING_RELEVANCE: # TODO(b/330807319): allow set reference field after setting metric spec is allowed. 
- question_answering_relevance_input = gapic_evaluation_service_types.QuestionAnsweringRelevanceInput( - metric_spec=metric_spec, - instance=gapic_evaluation_service_types.QuestionAnsweringRelevanceInstance( - prediction=prediction, - context=context, - instruction=instruction, - ), + question_answering_relevance_input = ( + gapic_eval_service_types.QuestionAnsweringRelevanceInput( + metric_spec=metric_spec, + instance=gapic_eval_service_types.QuestionAnsweringRelevanceInstance( + prediction=prediction, + context=context, + instruction=instruction, + ), + ) ) - return gapic_evaluation_service_types.EvaluateInstancesRequest( + return gapic_eval_service_types.EvaluateInstancesRequest( location=location_path, question_answering_relevance_input=question_answering_relevance_input, ) if metric_name == constants.Metric.QUESTION_ANSWERING_CORRECTNESS: # TODO(b/330807319): allow set reference field after setting metric spec is allowed. - question_answering_correctness_input = gapic_evaluation_service_types.QuestionAnsweringCorrectnessInput( - metric_spec=metric_spec, - instance=gapic_evaluation_service_types.QuestionAnsweringCorrectnessInstance( - prediction=prediction, - context=context, - instruction=instruction, - reference=reference, - ), + question_answering_correctness_input = ( + gapic_eval_service_types.QuestionAnsweringCorrectnessInput( + metric_spec=metric_spec, + instance=gapic_eval_service_types.QuestionAnsweringCorrectnessInstance( + prediction=prediction, + context=context, + instruction=instruction, + reference=reference, + ), + ) ) - return gapic_evaluation_service_types.EvaluateInstancesRequest( + return gapic_eval_service_types.EvaluateInstancesRequest( location=location_path, question_answering_correctness_input=question_answering_correctness_input, ) @@ -405,10 +415,42 @@ def build_request( raise NotImplementedError("RAG context recall is not implemented.") # Side-by-side(SxS) Pairwise Metrics. 
if metric_name == constants.Metric.PAIRWISE_SUMMARIZATION_QUALITY: - raise NotImplementedError("Pairwise summarization quality is not implemented.") + instance = gapic_eval_service_types.PairwiseSummarizationQualityInstance( + prediction=prediction, + baseline_prediction=baseline_prediction, + context=context, + instruction=instruction, + ) + if metric_spec.use_reference: + instance.reference = reference + pairwise_summarization_quality_input = ( + gapic_eval_service_types.PairwiseSummarizationQualityInput( + metric_spec=metric_spec, + instance=instance, + ) + ) + return gapic_eval_service_types.EvaluateInstancesRequest( + location=location_path, + pairwise_summarization_quality_input=pairwise_summarization_quality_input, + ) if metric_name == constants.Metric.PAIRWISE_QUESTION_ANSWERING_QUALITY: - raise NotImplementedError( - "Pairwise question answering quality is not implemented." + instance = gapic_eval_service_types.PairwiseQuestionAnsweringQualityInstance( + prediction=prediction, + baseline_prediction=baseline_prediction, + context=context, + instruction=instruction, + ) + if metric_spec.use_reference: + instance.reference = reference + pairwise_question_answering_quality_input = ( + gapic_eval_service_types.PairwiseQuestionAnsweringQualityInput( + metric_spec=metric_spec, + instance=instance, + ) + ) + return gapic_eval_service_types.EvaluateInstancesRequest( + location=location_path, + pairwise_question_answering_quality_input=pairwise_question_answering_quality_input, ) @@ -462,20 +504,16 @@ def _parse_pairwise_results( ) -> Dict[str, Any]: """Parses the pairwise metric results from the evaluation results. - s - - Args: - metric_result_dict: The metric results dictionary. + Args: + metric_result_dict: The metric results dictionary. - Returns: - A dictionary containing metric score, explanation, confidence of the - metric. + Returns: + A dictionary containing metric score, explanation, confidence of the + metric. 
""" return { - # TODO(b/330598854): handle pairwise choice. constants.MetricResult.PAIRWISE_CHOICE_KEY: metric_result_dict.get( constants.MetricResult.PAIRWISE_CHOICE_KEY, - gapic_evaluation_service_types.PairwiseChoice.PAIRWISE_CHOICE_UNSPECIFIED, ), constants.MetricResult.EXPLANATION_KEY: metric_result_dict.get( constants.MetricResult.EXPLANATION_KEY @@ -487,7 +525,7 @@ def _parse_pairwise_results( def _handle_response( - response: gapic_evaluation_service_types.EvaluateInstancesResponse, + response: gapic_eval_service_types.EvaluateInstancesResponse, ) -> Dict[str, Any]: """Handles the response from the evaluation service. @@ -570,7 +608,7 @@ def _handle_response( async def evaluate_instances_async( client: gapic_evaluation_services.EvaluationServiceAsyncClient, - request: gapic_evaluation_service_types.EvaluateInstancesRequest, + request: gapic_eval_service_types.EvaluateInstancesRequest, ): """Evaluates an instance asynchronously. From 195c77ed7320aea3ab5899427a922d606ed78997 Mon Sep 17 00:00:00 2001 From: Jason Dai Date: Mon, 6 May 2024 11:56:08 -0700 Subject: [PATCH 09/30] fix: Add MAX_TOKENS to the list of successful finish reasons for Rapid Evaluation SDK PiperOrigin-RevId: 631138372 --- vertexai/preview/evaluation/_evaluation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vertexai/preview/evaluation/_evaluation.py b/vertexai/preview/evaluation/_evaluation.py index 476f9663e1..c0e5464f8d 100644 --- a/vertexai/preview/evaluation/_evaluation.py +++ b/vertexai/preview/evaluation/_evaluation.py @@ -82,6 +82,7 @@ } _SUCCESSFUL_FINISH_REASONS = [ gapic_content_types.Candidate.FinishReason.STOP, + gapic_content_types.Candidate.FinishReason.MAX_TOKENS, # Many responses have this finish reason gapic_content_types.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, ] From 20b18668f15c448813aad4f58f2a4d470d6da2ec Mon Sep 17 00:00:00 2001 From: Jaycee Li Date: Mon, 6 May 2024 16:10:29 -0700 Subject: [PATCH 10/30] fix: AttributeError for 
TorchModelSerializer.deserialize in torch >=2.3.0 PiperOrigin-RevId: 631215839 --- .../serialization_engine/serializers.py | 36 ++++++++++++++----- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/vertexai/preview/_workflow/serialization_engine/serializers.py b/vertexai/preview/_workflow/serialization_engine/serializers.py index e4dca7e3a5..cf49f1f2b8 100644 --- a/vertexai/preview/_workflow/serialization_engine/serializers.py +++ b/vertexai/preview/_workflow/serialization_engine/serializers.py @@ -144,6 +144,7 @@ def _is_valid_gcs_path(path: str) -> bool: def _load_torch_model(path: str, map_location: "torch.device") -> "torch.nn.Module": import torch + try: return torch.load(path, map_location=map_location) except Exception: @@ -434,7 +435,9 @@ class TorchModelSerializer(serializers_base.Serializer): serializers_base.SerializationMetadata(serializer="TorchModelSerializer") ) - def serialize(self, to_serialize: "torch.nn.Module", gcs_path: str, **kwargs) -> str: + def serialize( + self, to_serialize: "torch.nn.Module", gcs_path: str, **kwargs + ) -> str: """Serializes a torch.nn.Module to a gcs path. Args: @@ -450,6 +453,7 @@ def serialize(self, to_serialize: "torch.nn.Module", gcs_path: str, **kwargs) -> ValueError: if `gcs_path` is not a valid GCS uri. """ import torch + del kwargs if not _is_valid_gcs_path(gcs_path): raise ValueError(f"Invalid gcs path: {gcs_path}") @@ -500,11 +504,18 @@ def deserialize(self, serialized_gcs_path: str, **kwargs) -> "torch.nn.Module": except ImportError as e: raise ImportError("torch is not installed.") from e - map_location = ( - torch._GLOBAL_DEVICE_CONTEXT.device - if torch._GLOBAL_DEVICE_CONTEXT - else None - ) + # Get the default device in the local torch environment. + # If `set_default_device` hasn't been called, _GLOBAL_DEVICE_CONTEXT + # should be None, then we set map_location to None as well. 
+ map_location = None + # In torch 2.3.0, get_default_device is introduced + if hasattr(torch._GLOBAL_DEVICE_CONTEXT, "device_context") and hasattr( + torch, "get_default_device" + ): + map_location = torch.get_default_device() + # For older versions, we get default device from _GLOBAL_DEVICE_CONTEXT + elif hasattr(torch._GLOBAL_DEVICE_CONTEXT, "device"): + map_location = torch._GLOBAL_DEVICE_CONTEXT.device if serialized_gcs_path.startswith("gs://"): with tempfile.NamedTemporaryFile() as temp_file: @@ -731,7 +742,9 @@ class TorchDataLoaderSerializer(serializers_base.Serializer): serializers_base.SerializationMetadata(serializer="TorchDataLoaderSerializer") ) - def _serialize_to_local(self, to_serialize: "torch.utils.data.DataLoader", path: str): + def _serialize_to_local( + self, to_serialize: "torch.utils.data.DataLoader", path: str + ): """Serializes a torch.utils.data.DataLoader to a local path. Args: @@ -778,6 +791,7 @@ def _serialize_to_local(self, to_serialize: "torch.utils.data.DataLoader", path: # for default batch sampler we store batch_size, drop_last, and sampler object # but not batch sampler object. import torch + if isinstance(to_serialize.batch_sampler, torch.utils.data.BatchSampler): pass_through_args["batch_size"] = to_serialize.batch_size pass_through_args["drop_last"] = to_serialize.drop_last @@ -797,7 +811,9 @@ def _serialize_to_local(self, to_serialize: "torch.utils.data.DataLoader", path: with open(f"{path}/pass_through_args.json", "w") as f: json.dump(pass_through_args, f) - def serialize(self, to_serialize: "torch.utils.data.DataLoader", gcs_path: str, **kwargs) -> str: + def serialize( + self, to_serialize: "torch.utils.data.DataLoader", gcs_path: str, **kwargs + ) -> str: """Serializes a torch.utils.data.DataLoader to a gcs path. 
Args: @@ -883,7 +899,9 @@ def _deserialize_from_local(self, path: str) -> "torch.utils.data.DataLoader": return torch.utils.data.DataLoader(**kwargs) - def deserialize(self, serialized_gcs_path: str, **kwargs) -> "torch.utils.data.DataLoader": + def deserialize( + self, serialized_gcs_path: str, **kwargs + ) -> "torch.utils.data.DataLoader": """Deserialize a torch.utils.data.DataLoader given the gcs path. Args: From 88188d294fc2ec55ec0b05640dc791a1a3a88255 Mon Sep 17 00:00:00 2001 From: Alexey Volkov Date: Mon, 6 May 2024 18:32:27 -0700 Subject: [PATCH 11/30] feat: GenAI - Tuning - Supervised - Added support for the `adapter_size` parameter PiperOrigin-RevId: 631251312 --- vertexai/tuning/_supervised_tuning.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vertexai/tuning/_supervised_tuning.py b/vertexai/tuning/_supervised_tuning.py index 047b53cfa3..f5b3fd0e96 100644 --- a/vertexai/tuning/_supervised_tuning.py +++ b/vertexai/tuning/_supervised_tuning.py @@ -13,7 +13,7 @@ # limitations under the License. # -from typing import Optional, Union +from typing import Literal, Optional, Union from google.cloud.aiplatform_v1.types import tuning_job as gca_tuning_job_types @@ -29,6 +29,7 @@ def train( tuned_model_display_name: Optional[str] = None, epochs: Optional[int] = None, learning_rate_multiplier: Optional[float] = None, + adapter_size: Optional[Literal[1, 4, 8, 16]] = None, ) -> "SupervisedTuningJob": """Tunes a model using supervised training. @@ -44,6 +45,7 @@ def train( be up to 128 characters long and can consist of any UTF-8 characters. epochs: Number of training epoches for this tuning job. learning_rate_multiplier: Learning rate multiplier for tuning. + adapter_size: Adapter size for tuning. Returns: A `TuningJob` object. 
@@ -54,6 +56,7 @@ def train( hyper_parameters=gca_tuning_job_types.SupervisedHyperParameters( epoch_count=epochs, learning_rate_multiplier=learning_rate_multiplier, + adapter_size=adapter_size, ), ) From bae8429ae078c69574d86280ae6c784aaa9b13b5 Mon Sep 17 00:00:00 2001 From: Alexey Volkov Date: Mon, 6 May 2024 23:59:39 -0700 Subject: [PATCH 12/30] feat: LLM - Made the tuning location parameters truly optional PiperOrigin-RevId: 631312780 --- vertexai/language_models/_language_models.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index e5c20e880e..2e0aeb8fe6 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -283,6 +283,11 @@ def tune_model( ValueError: If the "tuned_model_location" value is not supported RuntimeError: If the model does not support tuning """ + if tuning_job_location is None: + tuning_job_location = aiplatform_initializer.global_config.location + if tuned_model_location is None: + tuned_model_location = aiplatform_initializer.global_config.location + tuning_parameters = {} if batch_size is not None: tuning_parameters["batch_size"] = batch_size @@ -623,7 +628,6 @@ def _tune_model_rlhf( Args: tuning_parameters: Tuning pipeline parameter values. - tuning_job_location: GCP location where the tuning job should be run. Returns: A `LanguageModelTuningJob` object that represents the tuning job. 
From e47d436f24cc718e378a28c4a80293778e8c183a Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Tue, 7 May 2024 11:14:04 -0700 Subject: [PATCH 13/30] feat: add FeatureGroup init/get PiperOrigin-RevId: 631486385 --- google/cloud/aiplatform/compat/__init__.py | 8 ++ .../aiplatform/compat/services/__init__.py | 8 ++ .../cloud/aiplatform/compat/types/__init__.py | 2 + google/cloud/aiplatform/utils/__init__.py | 19 ++++ .../unit/vertexai/feature_store_constants.py | 16 +++ tests/unit/vertexai/test_feature_group.py | 103 ++++++++++++++++++ tests/unit/vertexai/test_feature_view.py | 8 +- vertexai/resources/preview/__init__.py | 2 + .../preview/feature_store/__init__.py | 5 + .../preview/feature_store/feature_group.py | 79 ++++++++++++++ .../preview/feature_store/feature_view.py | 14 +-- .../resources/preview/feature_store/utils.py | 8 ++ 12 files changed, 261 insertions(+), 11 deletions(-) create mode 100644 tests/unit/vertexai/test_feature_group.py create mode 100644 vertexai/resources/preview/feature_store/feature_group.py diff --git a/google/cloud/aiplatform/compat/__init__.py b/google/cloud/aiplatform/compat/__init__.py index 43e9bff969..b7f48a3a06 100644 --- a/google/cloud/aiplatform/compat/__init__.py +++ b/google/cloud/aiplatform/compat/__init__.py @@ -36,6 +36,9 @@ services.feature_online_store_service_client = ( services.feature_online_store_service_client_v1beta1 ) + services.feature_registry_service_client = ( + services.feature_registry_service_client_v1beta1 + ) services.featurestore_online_serving_service_client = ( services.featurestore_online_serving_service_client_v1beta1 ) @@ -91,6 +94,7 @@ types.feature_online_store_admin_service = ( types.feature_online_store_admin_service_v1beta1 ) + types.feature_registry_service = types.feature_registry_service_v1beta1 types.feature_online_store_service = types.feature_online_store_service_v1beta1 types.feature_selector = types.feature_selector_v1beta1 types.feature_view = types.feature_view_v1beta1 @@ 
-157,6 +161,9 @@ services.feature_online_store_admin_service_client = ( services.feature_online_store_admin_service_client_v1 ) + services.feature_registry_service_client = ( + services.feature_registry_service_client_v1 + ) services.feature_online_store_service_client = ( services.feature_online_store_service_client_v1 ) @@ -208,6 +215,7 @@ types.feature_online_store_admin_service = ( types.feature_online_store_admin_service_v1 ) + types.feature_registry_service = types.feature_registry_service_v1 types.feature_online_store_service = types.feature_online_store_service_v1 types.feature_selector = types.feature_selector_v1 types.feature_view = types.feature_view_v1 diff --git a/google/cloud/aiplatform/compat/services/__init__.py b/google/cloud/aiplatform/compat/services/__init__.py index 86bed88448..4d6c2aef9d 100644 --- a/google/cloud/aiplatform/compat/services/__init__.py +++ b/google/cloud/aiplatform/compat/services/__init__.py @@ -36,6 +36,9 @@ from google.cloud.aiplatform_v1beta1.services.feature_online_store_admin_service import ( client as feature_online_store_admin_service_client_v1beta1, ) +from google.cloud.aiplatform_v1beta1.services.feature_registry_service import ( + client as feature_registry_service_client_v1beta1, +) from google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service import ( client as featurestore_online_serving_service_client_v1beta1, ) @@ -119,6 +122,9 @@ from google.cloud.aiplatform_v1.services.feature_online_store_admin_service import ( client as feature_online_store_admin_service_client_v1, ) +from google.cloud.aiplatform_v1.services.feature_registry_service import ( + client as feature_registry_service_client_v1, +) from google.cloud.aiplatform_v1.services.featurestore_online_serving_service import ( client as featurestore_online_serving_service_client_v1, ) @@ -174,6 +180,7 @@ endpoint_service_client_v1, feature_online_store_service_client_v1, feature_online_store_admin_service_client_v1, + 
feature_registry_service_client_v1, featurestore_online_serving_service_client_v1, featurestore_service_client_v1, index_service_client_v1, @@ -196,6 +203,7 @@ endpoint_service_client_v1beta1, feature_online_store_service_client_v1beta1, feature_online_store_admin_service_client_v1beta1, + feature_registry_service_client_v1beta1, featurestore_online_serving_service_client_v1beta1, featurestore_service_client_v1beta1, index_service_client_v1beta1, diff --git a/google/cloud/aiplatform/compat/types/__init__.py b/google/cloud/aiplatform/compat/types/__init__.py index ad11da8a3f..470d72088d 100644 --- a/google/cloud/aiplatform/compat/types/__init__.py +++ b/google/cloud/aiplatform/compat/types/__init__.py @@ -47,6 +47,7 @@ feature_online_store as feature_online_store_v1beta1, feature_online_store_admin_service as feature_online_store_admin_service_v1beta1, feature_online_store_service as feature_online_store_service_v1beta1, + feature_registry_service as feature_registry_service_v1beta1, feature_selector as feature_selector_v1beta1, feature_view as feature_view_v1beta1, feature_view_sync as feature_view_sync_v1beta1, @@ -136,6 +137,7 @@ feature_online_store as feature_online_store_v1, feature_online_store_admin_service as feature_online_store_admin_service_v1, feature_online_store_service as feature_online_store_service_v1, + feature_registry_service as feature_registry_service_v1, feature_selector as feature_selector_v1, feature_view as feature_view_v1, feature_view_sync as feature_view_sync_v1, diff --git a/google/cloud/aiplatform/utils/__init__.py b/google/cloud/aiplatform/utils/__init__.py index 1553001cbd..ea798e1101 100644 --- a/google/cloud/aiplatform/utils/__init__.py +++ b/google/cloud/aiplatform/utils/__init__.py @@ -43,6 +43,7 @@ extension_registry_service_client_v1beta1, feature_online_store_admin_service_client_v1beta1, feature_online_store_service_client_v1beta1, + feature_registry_service_client_v1beta1, featurestore_online_serving_service_client_v1beta1, 
featurestore_service_client_v1beta1, index_service_client_v1beta1, @@ -71,6 +72,7 @@ endpoint_service_client_v1, feature_online_store_admin_service_client_v1, feature_online_store_service_client_v1, + feature_registry_service_client_v1, featurestore_online_serving_service_client_v1, featurestore_service_client_v1, index_service_client_v1, @@ -100,6 +102,7 @@ endpoint_service_client_v1beta1.EndpointServiceClient, feature_online_store_admin_service_client_v1beta1.FeatureOnlineStoreAdminServiceClient, feature_online_store_service_client_v1beta1.FeatureOnlineStoreServiceClient, + feature_registry_service_client_v1beta1.FeatureRegistryServiceClient, featurestore_online_serving_service_client_v1beta1.FeaturestoreOnlineServingServiceClient, featurestore_service_client_v1beta1.FeaturestoreServiceClient, index_service_client_v1beta1.IndexServiceClient, @@ -120,6 +123,7 @@ endpoint_service_client_v1.EndpointServiceClient, feature_online_store_admin_service_client_v1.FeatureOnlineStoreAdminServiceClient, feature_online_store_service_client_v1.FeatureOnlineStoreServiceClient, + feature_registry_service_client_v1.FeatureRegistryServiceClient, featurestore_online_serving_service_client_v1.FeaturestoreOnlineServingServiceClient, featurestore_service_client_v1.FeaturestoreServiceClient, metadata_service_client_v1.MetadataServiceClient, @@ -635,6 +639,21 @@ class FeatureOnlineStoreClientWithOverride(ClientWithOverride): ) +class FeatureRegistryClientWithOverride(ClientWithOverride): + _is_temporary = True + _default_version = compat.DEFAULT_VERSION + _version_map = ( + ( + compat.V1, + feature_registry_service_client_v1.FeatureRegistryServiceClient, + ), + ( + compat.V1BETA1, + feature_registry_service_client_v1beta1.FeatureRegistryServiceClient, + ), + ) + + class FeaturestoreClientWithOverride(ClientWithOverride): _is_temporary = True _default_version = compat.DEFAULT_VERSION diff --git a/tests/unit/vertexai/feature_store_constants.py 
b/tests/unit/vertexai/feature_store_constants.py index f13fc1ab0d..0c04dbcdc5 100644 --- a/tests/unit/vertexai/feature_store_constants.py +++ b/tests/unit/vertexai/feature_store_constants.py @@ -262,3 +262,19 @@ ] ) ) + +_TEST_FG1_ID = "my_fg1" +_TEST_FG1_PATH = f"{_TEST_PARENT}/featureGroups/{_TEST_FG1_ID}" +_TEST_FG1_BQ_URI = f"bq://{_TEST_PROJECT}.my_dataset.my_table_for_fg1" +_TEST_FG1_ENTITY_ID_COLUMNS = ["entity_id"] +_TEST_FG1_LABELS = {"my_key": "my_fg1"} +_TEST_FG1 = types.feature_group.FeatureGroup( + name=_TEST_FG1_PATH, + big_query=types.feature_group.FeatureGroup.BigQuery( + big_query_source=types.io.BigQuerySource( + input_uri=_TEST_FG1_BQ_URI, + ), + entity_id_columns=_TEST_FG1_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FG1_LABELS, +) diff --git a/tests/unit/vertexai/test_feature_group.py b/tests/unit/vertexai/test_feature_group.py new file mode 100644 index 0000000000..3120818609 --- /dev/null +++ b/tests/unit/vertexai/test_feature_group.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from typing import Dict, List +from unittest.mock import patch + +from google.cloud import aiplatform +from google.cloud.aiplatform import base +from vertexai.resources.preview import ( + FeatureGroup, +) +import vertexai.resources.preview.feature_store.utils as fs_utils +import pytest +from google.cloud.aiplatform.compat.services import ( + feature_registry_service_client, +) + + +from feature_store_constants import ( + _TEST_PROJECT, + _TEST_LOCATION, + _TEST_FG1, + _TEST_FG1_ID, + _TEST_FG1_PATH, + _TEST_FG1_BQ_URI, + _TEST_FG1_ENTITY_ID_COLUMNS, + _TEST_FG1_LABELS, +) + + +pytestmark = pytest.mark.usefixtures("google_auth_mock") + + +@pytest.fixture +def get_fg_mock(): + with patch.object( + feature_registry_service_client.FeatureRegistryServiceClient, + "get_feature_group", + ) as get_fg_mock: + get_fg_mock.return_value = _TEST_FG1 + yield get_fg_mock + + +def fg_eq( + fg_to_check: FeatureGroup, + name: str, + resource_name: str, + source_uri: str, + entity_id_columns: List[str], + project: str, + location: str, + labels: Dict[str, str], +): + """Check if a FeatureGroup has the appropriate values set.""" + assert fg_to_check.name == name + assert fg_to_check.resource_name == resource_name + assert fg_to_check.source == fs_utils.FeatureGroupBigQuerySource( + uri=source_uri, + entity_id_columns=entity_id_columns, + ) + assert fg_to_check.project == project + assert fg_to_check.location == location + assert fg_to_check.labels == labels + + +@pytest.mark.parametrize( + "feature_group_name", + [_TEST_FG1_ID, _TEST_FG1_PATH], +) +def test_init(feature_group_name, get_fg_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fg = FeatureGroup(feature_group_name) + + get_fg_mock.assert_called_once_with( + name=_TEST_FG1_PATH, + retry=base._DEFAULT_RETRY, + ) + + fg_eq( + fg, + name=_TEST_FG1_ID, + resource_name=_TEST_FG1_PATH, + source_uri=_TEST_FG1_BQ_URI, + entity_id_columns=_TEST_FG1_ENTITY_ID_COLUMNS, + project=_TEST_PROJECT, + 
location=_TEST_LOCATION, + labels=_TEST_FG1_LABELS, + ) diff --git a/tests/unit/vertexai/test_feature_view.py b/tests/unit/vertexai/test_feature_view.py index a5b9b146dd..0d330c5fcd 100644 --- a/tests/unit/vertexai/test_feature_view.py +++ b/tests/unit/vertexai/test_feature_view.py @@ -26,7 +26,7 @@ from vertexai.resources.preview import ( FeatureView, ) -import vertexai.resources.preview.feature_store.utils as fv_utils +import vertexai.resources.preview.feature_store.utils as fs_utils import pytest from google.cloud.aiplatform.compat.services import ( feature_online_store_admin_service_client, @@ -418,7 +418,7 @@ def test_fetch_feature_values_optimized_no_endpoint( ): """Tests that the public endpoint is not created for the optimized online store.""" with pytest.raises( - fv_utils.PublicEndpointNotFoundError, + fs_utils.PublicEndpointNotFoundError, match=re.escape( "Public endpoint is not created yet for the optimized online " "store:my_esf_optimised_fos2. Please run sync and wait for it " @@ -498,8 +498,8 @@ def test_search_nearest_entities_no_endpoint( try: FeatureView(_TEST_OPTIMIZED_FV2_PATH).search(entity_id="key1").to_dict() assert not fetch_feature_values_mock.called - except fv_utils.PublicEndpointNotFoundError as e: - assert isinstance(e, fv_utils.PublicEndpointNotFoundError) + except fs_utils.PublicEndpointNotFoundError as e: + assert isinstance(e, fs_utils.PublicEndpointNotFoundError) error_msg = ( "Public endpoint is not created yet for the optimized online " "store:my_esf_optimised_fos2. 
Please run sync and wait for it " diff --git a/vertexai/resources/preview/__init__.py b/vertexai/resources/preview/__init__.py index 2dd1cc5f10..aef94c6099 100644 --- a/vertexai/resources/preview/__init__.py +++ b/vertexai/resources/preview/__init__.py @@ -35,6 +35,7 @@ ) from vertexai.resources.preview.feature_store import ( + FeatureGroup, FeatureOnlineStore, FeatureOnlineStoreType, FeatureView, @@ -62,6 +63,7 @@ "PersistentResource", "EntityType", "PipelineJobSchedule", + "FeatureGroup", "FeatureOnlineStoreType", "FeatureOnlineStore", "FeatureView", diff --git a/vertexai/resources/preview/feature_store/__init__.py b/vertexai/resources/preview/feature_store/__init__.py index 39478d9ae9..07855d5088 100644 --- a/vertexai/resources/preview/feature_store/__init__.py +++ b/vertexai/resources/preview/feature_store/__init__.py @@ -16,6 +16,10 @@ # """The vertexai resources preview module.""" +from vertexai.resources.preview.feature_store.feature_group import ( + FeatureGroup, +) + from vertexai.resources.preview.feature_store.feature_online_store import ( FeatureOnlineStore, FeatureOnlineStoreType, @@ -36,6 +40,7 @@ ) __all__ = ( + FeatureGroup, FeatureOnlineStoreType, FeatureOnlineStore, FeatureView, diff --git a/vertexai/resources/preview/feature_store/feature_group.py b/vertexai/resources/preview/feature_store/feature_group.py new file mode 100644 index 0000000000..be437389cd --- /dev/null +++ b/vertexai/resources/preview/feature_store/feature_group.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional +from google.auth import credentials as auth_credentials +from google.cloud.aiplatform import base +from google.cloud.aiplatform import utils +from google.cloud.aiplatform.compat.types import ( + feature_group as gca_feature_group, +) +import vertexai.resources.preview.feature_store.utils as fs_utils + + +class FeatureGroup(base.VertexAiResourceNounWithFutureManager): + """Class for managing Feature Group resources.""" + + client_class = utils.FeatureRegistryClientWithOverride + + _resource_noun = "feature_groups" + _getter_method = "get_feature_group" + _list_method = "list_feature_groups" + _delete_method = "delete_feature_group" + _parse_resource_name_method = "parse_feature_group_path" + _format_resource_name_method = "feature_group_path" + _gca_resource: gca_feature_group.FeatureGroup + + def __init__( + self, + name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Retrieves an existing managed feature group. + + Args: + name: + The resource name + (`projects/.../locations/.../featureGroups/...`) or ID. + project: + Project to retrieve feature group from. If unset, the + project set in aiplatform.init will be used. + location: + Location to retrieve feature group from. If not set, + location set in aiplatform.init will be used. + credentials: + Custom credentials to use to retrieve this feature group. + Overrides credentials set in aiplatform.init. 
+ """ + + super().__init__( + project=project, + location=location, + credentials=credentials, + resource_name=name, + ) + + self._gca_resource = self._get_gca_resource(resource_name=name) + + @property + def source(self) -> fs_utils.FeatureGroupBigQuerySource: + return fs_utils.FeatureGroupBigQuerySource( + uri=self._gca_resource.big_query.big_query_source.input_uri, + entity_id_columns=self._gca_resource.big_query.entity_id_columns, + ) diff --git a/vertexai/resources/preview/feature_store/feature_view.py b/vertexai/resources/preview/feature_store/feature_view.py index ac08c29e21..f21d09cda5 100644 --- a/vertexai/resources/preview/feature_store/feature_view.py +++ b/vertexai/resources/preview/feature_store/feature_view.py @@ -26,7 +26,7 @@ feature_view as gca_feature_view, feature_online_store_service as fos_service, ) -import vertexai.resources.preview.feature_store.utils as fv_utils +import vertexai.resources.preview.feature_store.utils as fs_utils _LOGGER = base.Logger(__name__) @@ -113,7 +113,7 @@ def _get_online_store_client(self) -> utils.FeatureOnlineStoreClientWithOverride if getattr(self, "_online_store_client", None): return self._online_store_client - fos_name = fv_utils.get_feature_online_store_name(self.resource_name) + fos_name = fs_utils.get_feature_online_store_name(self.resource_name) from .feature_online_store import FeatureOnlineStore fos = FeatureOnlineStore(name=fos_name) @@ -130,7 +130,7 @@ def _get_online_store_client(self) -> utils.FeatureOnlineStoreClientWithOverride # From here, optimized serving. if not fos._gca_resource.dedicated_serving_endpoint.public_endpoint_domain_name: - raise fv_utils.PublicEndpointNotFoundError( + raise fs_utils.PublicEndpointNotFoundError( "Public endpoint is not created yet for the optimized online store:" f"{fos_name}. Please run sync and wait for it to complete." 
) @@ -253,7 +253,7 @@ def read( self, key: List[str], request_timeout: Optional[float] = None, - ) -> fv_utils.FeatureViewReadResponse: + ) -> fs_utils.FeatureViewReadResponse: """Read the feature values from FeatureView. Example Usage: @@ -279,7 +279,7 @@ def read( ), timeout=request_timeout, ) - return fv_utils.FeatureViewReadResponse(response) + return fs_utils.FeatureViewReadResponse(response) def search( self, @@ -294,7 +294,7 @@ def search( approximate_neighbor_candidates: Optional[int] = None, leaf_nodes_search_fraction: Optional[float] = None, request_timeout: Optional[float] = None, - ) -> fv_utils.SearchNearestEntitiesResponse: + ) -> fs_utils.SearchNearestEntitiesResponse: """Search the nearest entities from FeatureView. Example Usage: @@ -361,7 +361,7 @@ def search( ), timeout=request_timeout, ) - return fv_utils.SearchNearestEntitiesResponse(response) + return fs_utils.SearchNearestEntitiesResponse(response) class FeatureViewSync(base.VertexAiResourceNounWithFutureManager): """Class for managing Feature View Sync resources.""" diff --git a/vertexai/resources/preview/feature_store/utils.py b/vertexai/resources/preview/feature_store/utils.py index 6c49903e7a..4f93be8b35 100644 --- a/vertexai/resources/preview/feature_store/utils.py +++ b/vertexai/resources/preview/feature_store/utils.py @@ -164,3 +164,11 @@ def as_dict(self) -> Dict[str, Any]: else: config["brute_force_config"] = self.algorithm_config.as_dict() return config + + +@dataclass +class FeatureGroupBigQuerySource: + """BigQuery source for the Feature Group.""" + + uri: str + entity_id_columns: List[str] From c03767ca9ee23bce2f9738a265f5025dc5bce024 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Tue, 7 May 2024 12:54:33 -0700 Subject: [PATCH 14/30] chore: Support both full resource name and resource id in Model Monitoring SDK. 
PiperOrigin-RevId: 631518248 --- .../preview/ml_monitoring/model_monitors.py | 138 ++++++++++++++---- 1 file changed, 111 insertions(+), 27 deletions(-) diff --git a/vertexai/resources/preview/ml_monitoring/model_monitors.py b/vertexai/resources/preview/ml_monitoring/model_monitors.py index dfe18f65f4..6e26858f08 100644 --- a/vertexai/resources/preview/ml_monitoring/model_monitors.py +++ b/vertexai/resources/preview/ml_monitoring/model_monitors.py @@ -18,6 +18,7 @@ import copy import dataclasses import json +import re from typing import Any, Dict, List, Optional from google.auth import credentials as auth_credentials @@ -237,6 +238,53 @@ def _transform_field_schema( return result +def _get_schedule_name( + schedule_name: str +) -> str: + if schedule_name: + client = initializer.global_config.create_client( + client_class=utils.ScheduleClientWithOverride, + ) + if client.parse_schedule_path(schedule_name): + return schedule_name + elif re.match("^{}$".format("[0-9]{0,127}"), schedule_name): + return client.schedule_path( + project=initializer.global_config.project, + location=initializer.global_config.location, + schedule=schedule_name, + ) + else: + raise ValueError( + "schedule name must be of the format `projects/{project}/locations/{location}/schedules/{schedule}` or `{schedule}`" + ) + return schedule_name + + +def _get_model_monitoring_job_name( + model_monitoring_job_name: str, + model_monitor_name: str, +) -> str: + if model_monitoring_job_name: + client = initializer.global_config.create_client( + client_class=utils.ModelMonitoringClientWithOverride, + ) + if client.parse_model_monitoring_job_path(model_monitoring_job_name): + return model_monitoring_job_name + elif re.match("^{}$".format("[0-9]{0,127}"), model_monitoring_job_name): + model_monitor_name = model_monitor_name.split("/")[-1] + return client.model_monitoring_job_path( + project=initializer.global_config.project, + location=initializer.global_config.location, + model_monitor=model_monitor_name, 
+ model_monitoring_job=model_monitoring_job_name, + ) + else: + raise ValueError( + "model monitoring job name must be of the format `projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}` or `{model_monitoring_job}`" + ) + return model_monitoring_job_name + + @dataclasses.dataclass class MetricsSearchResponse: """MetricsSearchResponse represents a response of the search metrics request. @@ -784,6 +832,8 @@ def update_schedule( schedule_name (str): Required. The resource name of schedule that needs to be updated. Format: ``projects/{project}/locations/{location}/schedules/{schedule}`` + or + ``{schedule}`` display_name (str): Optional. The user-defined name of the Schedule. The name can be up to 128 characters long and can be consist of @@ -833,8 +883,9 @@ def update_schedule( project=self.project, location=self.location, ) - - current_schedule = copy.deepcopy(self.get_schedule(schedule_name=schedule_name)) + schedule_name = _get_schedule_name(schedule_name) + current_schedule = copy.deepcopy( + self.get_schedule(schedule_name=schedule_name)) update_mask = [] if display_name is not None: update_mask.append("display_name") @@ -911,13 +962,17 @@ def delete_schedule(self, schedule_name: str) -> None: schedule_name (str): Required. The resource name of schedule that needs to be deleted. Format: ``projects/{project}/locations/{location}/schedules/{schedule}`` + or + ``{schedule}`` """ api_client = initializer.global_config.create_client( client_class=utils.ScheduleClientWithOverride, credentials=self.credentials, location_override=self.location, ) - api_client.select_version("v1beta1").delete_schedule(name=schedule_name) + schedule_name = _get_schedule_name(schedule_name) + return api_client.select_version("v1beta1").delete_schedule( + name=schedule_name) def pause_schedule(self, schedule_name: str) -> None: """Pauses an existing Schedule. 
@@ -926,13 +981,17 @@ def pause_schedule(self, schedule_name: str) -> None: schedule_name (str): Required. The resource name of schedule that needs to be paused. Format: ``projects/{project}/locations/{location}/schedules/{schedule}`` + or + ``{schedule}`` """ api_client = initializer.global_config.create_client( client_class=utils.ScheduleClientWithOverride, credentials=self.credentials, location_override=self.location, ) - api_client.select_version("v1beta1").pause_schedule(name=schedule_name) + schedule_name = _get_schedule_name(schedule_name) + return api_client.select_version("v1beta1").pause_schedule( + name=schedule_name) def resume_schedule(self, schedule_name: str) -> None: """Resumes an existing Schedule. @@ -941,13 +1000,17 @@ def resume_schedule(self, schedule_name: str) -> None: schedule_name (str): Required. The resource name of schedule that needs to be resumed. Format: ``projects/{project}/locations/{location}/schedules/{schedule}`` + or + ``{schedule}`` """ api_client = initializer.global_config.create_client( client_class=utils.ScheduleClientWithOverride, credentials=self.credentials, location_override=self.location, ) - api_client.select_version("v1beta1").resume_schedule(name=schedule_name) + schedule_name = _get_schedule_name(schedule_name) + return api_client.select_version("v1beta1").resume_schedule( + name=schedule_name) def get_schedule(self, schedule_name: str) -> "gca_schedule.Schedule": """Gets an existing Schedule. @@ -956,6 +1019,8 @@ def get_schedule(self, schedule_name: str) -> "gca_schedule.Schedule": schedule_name (str): Required. The resource name of schedule that needs to be fetched. Format: ``projects/{project}/locations/{location}/schedules/{schedule}`` + or + ``{schedule}`` Returns: Schedule: The schedule requested. 
@@ -965,6 +1030,7 @@ def get_schedule(self, schedule_name: str) -> "gca_schedule.Schedule": credentials=self.credentials, location_override=self.location, ) + schedule_name = _get_schedule_name(schedule_name) return api_client.select_version("v1beta1").get_schedule(name=schedule_name) def list_schedules( @@ -1375,13 +1441,17 @@ def delete_model_monitoring_job(self, model_monitoring_job_name: str) -> None: needs to be deleted. Format: ``projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}`` + or + ``{model_monitoring_job}`` """ api_client = initializer.global_config.create_client( client_class=utils.ModelMonitoringClientWithOverride, credentials=self.credentials, location_override=self.location, ) - api_client.delete_model_monitoring_job(name=model_monitoring_job_name) + job_resource_name = _get_model_monitoring_job_name( + model_monitoring_job_name, self._gca_resource.name) + api_client.delete_model_monitoring_job(name=job_resource_name) def get_model_monitoring_job( self, model_monitoring_job_name: str @@ -1400,21 +1470,14 @@ def get_model_monitoring_job( ModelMonitoringJob: The model monitoring job get. 
""" self.wait() - if model_monitoring_job_name.startswith("projects/"): - return ModelMonitoringJob( - model_monitoring_job_name=model_monitoring_job_name, - project=self.project, - location=self.location, - credentials=self.credentials, - ) - else: - return ModelMonitoringJob( - model_monitoring_job_name=model_monitoring_job_name, - model_monitor_id=self._gca_resource.name, - project=self.project, - location=self.location, - credentials=self.credentials, - ) + job_resource_name = _get_model_monitoring_job_name( + model_monitoring_job_name, self._gca_resource.name) + return ModelMonitoringJob( + model_monitoring_job_name=job_resource_name, + project=self.project, + location=self.location, + credentials=self.credentials, + ) def show_feature_drift_stats(self, model_monitoring_job_name: str) -> None: """The method to visualize the feature drift result from a model monitoring job as a histogram chart and a table. @@ -1424,17 +1487,24 @@ def show_feature_drift_stats(self, model_monitoring_job_name: str) -> None: Required. The resource name of model monitoring job to show the drift stats from. 
Format: ``projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}`` + or + ``{model_monitoring_job}`` """ api_client = initializer.global_config.create_client( client_class=utils.ModelMonitoringClientWithOverride, credentials=self.credentials, location_override=self.location, ) - job = api_client.get_model_monitoring_job(name=model_monitoring_job_name) + if model_monitoring_job_name.startswith("projects/"): + job_resource_name = model_monitoring_job_name + job_id = model_monitoring_job_name.split("/")[-1] + else: + job_resource_name = f"{self._gca_resource.name}/modelMonitoringJobs/{model_monitoring_job_name}" + job_id = model_monitoring_job_name + job = api_client.get_model_monitoring_job(name=job_resource_name) output_directory = ( job.model_monitoring_spec.output_spec.gcs_base_directory.output_uri_prefix ) - job_id = model_monitoring_job_name.split("/")[-1] target_output, baseline_output = _feature_drift_stats_output_path( output_directory, job_id ) @@ -1455,17 +1525,24 @@ def show_output_drift_stats(self, model_monitoring_job_name: str) -> None: Required. The resource name of model monitoring job to show the drift stats from. 
Format: ``projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}`` + or + ``{model_monitoring_job}`` """ api_client = initializer.global_config.create_client( client_class=utils.ModelMonitoringClientWithOverride, credentials=self.credentials, location_override=self.location, ) - job = api_client.get_model_monitoring_job(name=model_monitoring_job_name) + if model_monitoring_job_name.startswith("projects/"): + job_resource_name = model_monitoring_job_name + job_id = model_monitoring_job_name.split("/")[-1] + else: + job_resource_name = f"{self._gca_resource.name}/modelMonitoringJobs/{model_monitoring_job_name}" + job_id = model_monitoring_job_name + job = api_client.get_model_monitoring_job(name=job_resource_name) output_directory = ( job.model_monitoring_spec.output_spec.gcs_base_directory.output_uri_prefix ) - job_id = model_monitoring_job_name.split("/")[-1] target_output, baseline_output = _prediction_output_stats_output_path( output_directory, job_id ) @@ -1486,17 +1563,24 @@ def show_feature_attribution_drift_stats( feature attribution drift stats from. 
Format: ``projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}`` + or + ``{model_monitoring_job}`` """ api_client = initializer.global_config.create_client( client_class=utils.ModelMonitoringClientWithOverride, credentials=self.credentials, location_override=self.location, ) - job = api_client.get_model_monitoring_job(name=model_monitoring_job_name) + if model_monitoring_job_name.startswith("projects/"): + job_resource_name = model_monitoring_job_name + job_id = model_monitoring_job_name.split("/")[-1] + else: + job_resource_name = f"{self._gca_resource.name}/modelMonitoringJobs/{model_monitoring_job_name}" + job_id = model_monitoring_job_name + job = api_client.get_model_monitoring_job(name=job_resource_name) output_directory = ( job.model_monitoring_spec.output_spec.gcs_base_directory.output_uri_prefix ) - job_id = model_monitoring_job_name.split("/")[-1] target_stats_output = _feature_attribution_target_stats_output_path( output_directory, job_id ) From cd85d8f74d3922de3f871415bacf77c594f0c547 Mon Sep 17 00:00:00 2001 From: Amy Wu Date: Tue, 7 May 2024 15:13:32 -0700 Subject: [PATCH 15/30] feat: Support vector_distance_threshold filtering and file-based retrieval for RAG PiperOrigin-RevId: 631564047 --- tests/unit/vertex_rag/test_rag_constants.py | 9 ++ tests/unit/vertex_rag/test_rag_retrieval.py | 42 +++++++-- tests/unit/vertex_rag/test_rag_store.py | 62 ++++++++++++++ tests/unit/vertex_ray/test_bigquery.py | 4 +- tests/unit/vertexai/test_generative_models.py | 11 ++- vertexai/preview/rag/__init__.py | 5 ++ vertexai/preview/rag/rag_data.py | 1 + vertexai/preview/rag/rag_retrieval.py | 75 +++++++++++++--- vertexai/preview/rag/rag_store.py | 85 ++++++++++++++++--- vertexai/preview/rag/utils/resources.py | 22 ++++- 10 files changed, 283 insertions(+), 33 deletions(-) create mode 100644 tests/unit/vertex_rag/test_rag_store.py diff --git a/tests/unit/vertex_rag/test_rag_constants.py 
b/tests/unit/vertex_rag/test_rag_constants.py index 7d0a05bdd3..847cefa31e 100644 --- a/tests/unit/vertex_rag/test_rag_constants.py +++ b/tests/unit/vertex_rag/test_rag_constants.py @@ -18,6 +18,7 @@ from vertexai.preview.rag.utils.resources import ( RagCorpus, RagFile, + RagResource, ) from google.cloud import aiplatform from google.cloud.aiplatform_v1beta1 import ( @@ -146,3 +147,11 @@ ] ) TEST_RETRIEVAL_RESPONSE = RetrieveContextsResponse(contexts=TEST_CONTEXTS) +TEST_RAG_RESOURCE = RagResource( + rag_corpus=TEST_RAG_CORPUS_RESOURCE_NAME, + rag_file_ids=[TEST_RAG_FILE_ID], +) +TEST_RAG_RESOURCE_INVALID_NAME = RagResource( + rag_corpus="213lkj-1/23jkl/", + rag_file_ids=[TEST_RAG_FILE_ID], +) diff --git a/tests/unit/vertex_rag/test_rag_retrieval.py b/tests/unit/vertex_rag/test_rag_retrieval.py index dc5080a97d..5aec0bd72a 100644 --- a/tests/unit/vertex_rag/test_rag_retrieval.py +++ b/tests/unit/vertex_rag/test_rag_retrieval.py @@ -70,11 +70,22 @@ def teardown_method(self): aiplatform.initializer.global_pool.shutdown(wait=True) @pytest.mark.usefixtures("retrieve_contexts_mock") - def test_retrieval_query_success(self): + def test_retrieval_query_rag_resources_success(self): response = rag.retrieval_query( - rag_corpora=[tc.TEST_RAG_CORPUS_RESOURCE_NAME], + rag_resources=[tc.TEST_RAG_RESOURCE], text=tc.TEST_QUERY_TEXT, similarity_top_k=2, + vector_distance_threshold=0.5, + ) + retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE) + + @pytest.mark.usefixtures("retrieve_contexts_mock") + def test_retrieval_query_rag_corpora_success(self): + response = rag.retrieval_query( + rag_corpora=[tc.TEST_RAG_CORPUS_ID], + text=tc.TEST_QUERY_TEXT, + similarity_top_k=2, + vector_distance_threshold=0.5, ) retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE) @@ -82,18 +93,39 @@ def test_retrieval_query_success(self): def test_retrieval_query_failure(self): with pytest.raises(RuntimeError) as e: rag.retrieval_query( - rag_corpora=[tc.TEST_RAG_CORPUS_RESOURCE_NAME], + 
rag_resources=[tc.TEST_RAG_RESOURCE], text=tc.TEST_QUERY_TEXT, similarity_top_k=2, + vector_distance_threshold=0.5, ) e.match("Failed in retrieving contexts due to") def test_retrieval_query_invalid_name(self): with pytest.raises(ValueError) as e: rag.retrieval_query( - # Should be RAG_CORPUS, not RAG_FILE - rag_corpora=[tc.TEST_RAG_FILE_RESOURCE_NAME], + rag_resources=[tc.TEST_RAG_RESOURCE_INVALID_NAME], text=tc.TEST_QUERY_TEXT, similarity_top_k=2, + vector_distance_threshold=0.5, ) e.match("Invalid RagCorpus name") + + def test_retrieval_query_multiple_rag_corpora(self): + with pytest.raises(ValueError) as e: + rag.retrieval_query( + rag_corpora=[tc.TEST_RAG_CORPUS_ID, tc.TEST_RAG_CORPUS_ID], + text=tc.TEST_QUERY_TEXT, + similarity_top_k=2, + vector_distance_threshold=0.5, + ) + e.match("Currently only support 1 RagCorpus") + + def test_retrieval_query_multiple_rag_resources(self): + with pytest.raises(ValueError) as e: + rag.retrieval_query( + rag_resources=[tc.TEST_RAG_RESOURCE, tc.TEST_RAG_RESOURCE], + text=tc.TEST_QUERY_TEXT, + similarity_top_k=2, + vector_distance_threshold=0.5, + ) + e.match("Currently only support 1 RagResource") diff --git a/tests/unit/vertex_rag/test_rag_store.py b/tests/unit/vertex_rag/test_rag_store.py new file mode 100644 index 0000000000..1718892210 --- /dev/null +++ b/tests/unit/vertex_rag/test_rag_store.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from vertexai.preview import rag +from vertexai.preview.generative_models import Tool +import pytest +import test_rag_constants as tc + + +@pytest.mark.usefixtures("google_auth_mock") +class TestRagStoreValidations: + def test_retrieval_tool_invalid_name(self): + with pytest.raises(ValueError) as e: + Tool.from_retrieval( + retrieval=rag.Retrieval( + source=rag.VertexRagStore( + rag_resources=[tc.TEST_RAG_RESOURCE_INVALID_NAME], + similarity_top_k=3, + vector_distance_threshold=0.4, + ), + ) + ) + e.match("Invalid RagCorpus name") + + def test_retrieval_tool_multiple_rag_corpora(self): + with pytest.raises(ValueError) as e: + Tool.from_retrieval( + retrieval=rag.Retrieval( + source=rag.VertexRagStore( + rag_corpora=[tc.TEST_RAG_CORPUS_ID, tc.TEST_RAG_CORPUS_ID], + similarity_top_k=3, + vector_distance_threshold=0.4, + ), + ) + ) + e.match("Currently only support 1 RagCorpus") + + def test_retrieval_tool_multiple_rag_resources(self): + with pytest.raises(ValueError) as e: + Tool.from_retrieval( + retrieval=rag.Retrieval( + source=rag.VertexRagStore( + rag_resources=[tc.TEST_RAG_RESOURCE, tc.TEST_RAG_RESOURCE], + similarity_top_k=3, + vector_distance_threshold=0.4, + ), + ) + ) + e.match("Currently only support 1 RagResource") diff --git a/tests/unit/vertex_ray/test_bigquery.py b/tests/unit/vertex_ray/test_bigquery.py index eea8fe6ee7..3356af7f9b 100644 --- a/tests/unit/vertex_ray/test_bigquery.py +++ b/tests/unit/vertex_ray/test_bigquery.py @@ -311,7 +311,7 @@ def test_do_write_dataset_exists(self, ray_remote_function_mock): assert len(write_tasks_list) == 4 # Ray 2.9.3 only - def test_write(self, ray_get_mock): + def test_write(self, ray_get_mock, ray_remote_function_mock): if _BigQueryDatasink is None: return bq_datasink = _BigQueryDatasink( @@ -327,7 +327,7 @@ def test_write(self, ray_get_mock): assert status == "ok" # Ray 2.9.3 only - def test_write_dataset_exists(self, ray_get_mock): + def test_write_dataset_exists(self, ray_get_mock, 
ray_remote_function_mock): if _BigQueryDatasink is None: return bq_datasink = _BigQueryDatasink( diff --git a/tests/unit/vertexai/test_generative_models.py b/tests/unit/vertexai/test_generative_models.py index a6683aa70a..5aeb3d1386 100644 --- a/tests/unit/vertexai/test_generative_models.py +++ b/tests/unit/vertexai/test_generative_models.py @@ -908,13 +908,18 @@ def test_generate_content_grounding_vertex_ai_search_retriever(self): ) def test_generate_content_vertex_rag_retriever(self): model = preview_generative_models.GenerativeModel("gemini-pro") + rag_resources = [ + rag.RagResource( + rag_corpus=f"projects/{_TEST_PROJECT}/locations/us-central1/ragCorpora/1234556", + rag_file_ids=["123", "456"], + ), + ] rag_retriever_tool = preview_generative_models.Tool.from_retrieval( retrieval=rag.Retrieval( source=rag.VertexRagStore( - rag_corpora=[ - f"projects/{_TEST_PROJECT}/locations/us-central1/ragCorpora/1234556" - ], + rag_resources=rag_resources, similarity_top_k=1, + vector_distance_threshold=0.4, ), ), ) diff --git a/vertexai/preview/rag/__init__.py b/vertexai/preview/rag/__init__.py index 2058b7713d..56e6b54d64 100644 --- a/vertexai/preview/rag/__init__.py +++ b/vertexai/preview/rag/__init__.py @@ -36,6 +36,10 @@ Retrieval, VertexRagStore, ) +from vertexai.preview.rag.utils.resources import ( + RagResource, +) + __all__ = ( "create_corpus", @@ -51,4 +55,5 @@ "retrieval_query", "Retrieval", "VertexRagStore", + "RagResource", ) diff --git a/vertexai/preview/rag/rag_data.py b/vertexai/preview/rag/rag_data.py index 149ada8b1b..ae1984af0f 100644 --- a/vertexai/preview/rag/rag_data.py +++ b/vertexai/preview/rag/rag_data.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +"""RAG data management SDK.""" from typing import Optional, Union, Sequence from google import auth diff --git a/vertexai/preview/rag/rag_retrieval.py b/vertexai/preview/rag/rag_retrieval.py index 32a650369e..519b15a822 100644 --- a/vertexai/preview/rag/rag_retrieval.py +++ b/vertexai/preview/rag/rag_retrieval.py @@ -14,6 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +"""Retrieval query to get relevant contexts.""" + +import re from typing import List, Optional from google.cloud.aiplatform import initializer @@ -25,41 +28,89 @@ from vertexai.preview.rag.utils import ( _gapic_utils, ) +from vertexai.preview.rag.utils.resources import RagResource def retrieval_query( - rag_corpora: List[str], text: str, + rag_resources: Optional[List[RagResource]] = None, + rag_corpora: Optional[List[str]] = None, similarity_top_k: Optional[int] = 10, + vector_distance_threshold: Optional[float] = 0.3, ) -> RetrieveContextsResponse: """Retrieve top k relevant docs/chunks. + Example usage: + ``` + import vertexai + + vertexai.init(project="my-project") + + results = vertexai.preview.rag.retrieval_query( + text="Why is the sky blue?", + rag_resources=[vertexai.preview.rag.RagResource( + rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1", + rag_file_ids=["rag-file-1", "rag-file-2", ...], + )], + similarity_top_k=2, + vector_distance_threshold=0.5, + ) + ``` + Args: - rag_corpora: A list of full resource name or corpus_id of the RagCorpus. Format: - ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}`` text: The query in text format to get relevant contexts. + rag_resources: A list of RagResource. It can be used to specify corpus + only or ragfiles. Currently only support one corpus or multiple files + from one corpus. In the future we may open up multiple corpora support. + rag_corpora: If rag_resources is not specified, use rag_corpora as a list + of rag corpora names. 
similarity_top_k: The number of contexts to retrieve. + vector_distance_threshold: Optional. Only return contexts with vector + distance smaller than the threshold. + Returns: RetrieveContextsResonse. """ parent = initializer.global_config.common_location_path() client = _gapic_utils.create_rag_service_client() - vertex_rag_store = RetrieveContextsRequest.VertexRagStore() - # Currently only support 1 RagCorpus. - if len(rag_corpora) > 1: - raise ValueError("Currently only support 1 RagCorpus.") - if len(rag_corpora[0].split("/")) == 6: - rag_corpus_name = rag_corpora[0] - elif len(rag_corpora[0].split("/")) == 1: - rag_corpus_name = parent + "/ragCorpora/" + rag_corpora[0] + + if rag_resources: + if len(rag_resources) > 1: + raise ValueError("Currently only support 1 RagResource.") + name = rag_resources[0].rag_corpus + elif rag_corpora: + if len(rag_corpora) > 1: + raise ValueError("Currently only support 1 RagCorpus.") + name = rag_corpora[0] + else: + raise ValueError("rag_resources or rag_corpora must be specified.") + + data_client = _gapic_utils.create_rag_data_service_client() + if data_client.parse_rag_corpus_path(name): + rag_corpus_name = name + elif re.match("^{}$".format(_gapic_utils._VALID_RESOURCE_NAME_REGEX), name): + rag_corpus_name = parent + "/ragCorpora/" + name else: raise ValueError( "Invalid RagCorpus name: %s. 
Proper format should be: projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}", rag_corpora, ) - vertex_rag_store.rag_corpora = [rag_corpus_name] + if rag_resources: + gapic_rag_resource = RetrieveContextsRequest.VertexRagStore.RagResource( + rag_corpus=rag_corpus_name, + rag_file_ids=rag_resources[0].rag_file_ids, + ) + vertex_rag_store = RetrieveContextsRequest.VertexRagStore( + rag_resources=[gapic_rag_resource], + ) + else: + vertex_rag_store = RetrieveContextsRequest.VertexRagStore( + rag_corpora=[rag_corpus_name], + ) + + vertex_rag_store.vector_distance_threshold = vector_distance_threshold query = RagQuery(text=text, similarity_top_k=similarity_top_k) request = RetrieveContextsRequest( vertex_rag_store=vertex_rag_store, diff --git a/vertexai/preview/rag/rag_store.py b/vertexai/preview/rag/rag_store.py index 363e33d4b0..ff51901255 100644 --- a/vertexai/preview/rag/rag_store.py +++ b/vertexai/preview/rag/rag_store.py @@ -14,10 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # -"""Tentative module in private_preview.""" +"""RAG retrieval tool for content generation.""" -from typing import List,Optional, Union +import re +from typing import List, Optional, Union from google.cloud.aiplatform_v1beta1.types import tool as gapic_tool_types +from google.cloud.aiplatform import initializer +from vertexai.preview.rag.utils import _gapic_utils +from vertexai.preview.rag.utils.resources import RagResource class Retrieval: @@ -39,18 +43,79 @@ class VertexRagStore: def __init__( self, - rag_corpora: List[str], - similarity_top_k: Optional[int], + rag_resources: Optional[List[RagResource]] = None, + rag_corpora: Optional[List[str]] = None, + similarity_top_k: Optional[int] = 10, + vector_distance_threshold: Optional[float] = 0.3, ): """Initializes a Vertex RAG store tool. 
+ Example usage: + ``` + import vertexai + + vertexai.init(project="my-project") + + results = vertexai.preview.rag.retrieval_query( + text="Why is the sky blue?", + rag_resources=[vertexai.preview.rag.RagResource( + rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1", + rag_file_ids=["rag-file-1", "rag-file-2", ...], + )], + similarity_top_k=2, + vector_distance_threshold=0.5, + ) + ``` + Args: - rag_corpora: A list of Vertex Rag Corpora resource name. Format: - projects/<>/locations/<>/ragCorpora/<>. + rag_resources: List of RagResource to retrieve from. It can be used + to specify corpus only or ragfiles. Currently only support one + corpus or multiple files from one corpus. In the future we + may open up multiple corpora support. + rag_corpora: If rag_resources is not specified, use rag_corpora as a + list of rag corpora names. similarity_top_k: Number of top k results to return from the selected corpora. + vector_distance_threshold (float): + Optional. Only return results with vector distance smaller than the threshold. + """ - self._raw_vertex_rag_store = gapic_tool_types.VertexRagStore( - rag_corpora=rag_corpora, - similarity_top_k=similarity_top_k, - ) + + if rag_resources: + if len(rag_resources) > 1: + raise ValueError("Currently only support 1 RagResource.") + name = rag_resources[0].rag_corpus + elif rag_corpora: + if len(rag_corpora) > 1: + raise ValueError("Currently only support 1 RagCorpus.") + name = rag_corpora[0] + else: + raise ValueError("rag_resources or rag_corpora must be specified.") + + data_client = _gapic_utils.create_rag_data_service_client() + if data_client.parse_rag_corpus_path(name): + rag_corpus_name = name + elif re.match("^{}$".format(_gapic_utils._VALID_RESOURCE_NAME_REGEX), name): + parent = initializer.global_config.common_location_path() + rag_corpus_name = parent + "/ragCorpora/" + name + else: + raise ValueError( + "Invalid RagCorpus name: %s. 
Proper format should be: projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}", + rag_corpora, + ) + if rag_resources: + gapic_rag_resource = gapic_tool_types.VertexRagStore.RagResource( + rag_corpus=rag_corpus_name, + rag_file_ids=rag_resources[0].rag_file_ids, + ) + self._raw_vertex_rag_store = gapic_tool_types.VertexRagStore( + rag_resources=[gapic_rag_resource], + similarity_top_k=similarity_top_k, + vector_distance_threshold=vector_distance_threshold, + ) + else: + self._raw_vertex_rag_store = gapic_tool_types.VertexRagStore( + rag_corpora=[rag_corpus_name], + similarity_top_k=similarity_top_k, + vector_distance_threshold=vector_distance_threshold, + ) diff --git a/vertexai/preview/rag/utils/resources.py b/vertexai/preview/rag/utils/resources.py index c3e1001153..d64eecb5fe 100644 --- a/vertexai/preview/rag/utils/resources.py +++ b/vertexai/preview/rag/utils/resources.py @@ -16,7 +16,7 @@ # import dataclasses -from typing import Optional +from typing import List, Optional @dataclasses.dataclass @@ -49,3 +49,23 @@ class RagCorpus: name: Optional[str] = None display_name: Optional[str] = None description: Optional[str] = None + + +@dataclasses.dataclass +class RagResource: + """RagResource. + + The representation of the rag source. It can be used to specify corpus only + or ragfiles. Currently only support one corpus or multiple files from one + corpus. In the future we may open up multiple corpora support. + + Attributes: + rag_corpus: A Rag corpus resource name or corpus id. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}`` + or ``{rag_corpus_id}``. + rag_files_id: List of Rag file resource name or file ids in the same corpus. Format: + ``{rag_file}``. 
+ """ + + rag_corpus: Optional[str] = None + rag_file_ids: Optional[List[str]] = None From 393810728b6b940e4cc8e1ac7f55875e3b750beb Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Wed, 8 May 2024 12:02:05 -0700 Subject: [PATCH 16/30] feat: add FeatureGroup create function PiperOrigin-RevId: 631879143 --- tests/unit/vertexai/test_feature_group.py | 138 +++++++++++++++- vertexai/resources/preview/__init__.py | 1 + .../preview/feature_store/__init__.py | 2 + .../preview/feature_store/feature_group.py | 147 +++++++++++++++++- 4 files changed, 280 insertions(+), 8 deletions(-) diff --git a/tests/unit/vertexai/test_feature_group.py b/tests/unit/vertexai/test_feature_group.py index 3120818609..ba986c3ef1 100644 --- a/tests/unit/vertexai/test_feature_group.py +++ b/tests/unit/vertexai/test_feature_group.py @@ -15,22 +15,32 @@ # limitations under the License. # +import re from typing import Dict, List -from unittest.mock import patch +from unittest import mock +from unittest.mock import call, patch +from google.api_core import operation as ga_operation from google.cloud import aiplatform from google.cloud.aiplatform import base +from vertexai.resources.preview.feature_store import ( + feature_group, +) from vertexai.resources.preview import ( FeatureGroup, ) -import vertexai.resources.preview.feature_store.utils as fs_utils +from vertexai.resources.preview.feature_store import ( + FeatureGroupBigQuerySource, +) import pytest from google.cloud.aiplatform.compat.services import ( feature_registry_service_client, ) +from google.cloud.aiplatform.compat import types from feature_store_constants import ( + _TEST_PARENT, _TEST_PROJECT, _TEST_LOCATION, _TEST_FG1, @@ -45,6 +55,16 @@ pytestmark = pytest.mark.usefixtures("google_auth_mock") +@pytest.fixture +def fg_logger_mock(): + with patch.object( + feature_group._LOGGER, + "info", + wraps=feature_group._LOGGER.info, + ) as logger_mock: + yield logger_mock + + @pytest.fixture def get_fg_mock(): with patch.object( @@ 
-55,6 +75,18 @@ def get_fg_mock(): yield get_fg_mock +@pytest.fixture +def create_fg_mock(): + with patch.object( + feature_registry_service_client.FeatureRegistryServiceClient, + "create_feature_group", + ) as create_fg_mock: + create_fg_lro_mock = mock.Mock(ga_operation.Operation) + create_fg_lro_mock.result.return_value = _TEST_FG1 + create_fg_mock.return_value = create_fg_lro_mock + yield create_fg_mock + + def fg_eq( fg_to_check: FeatureGroup, name: str, @@ -68,7 +100,7 @@ def fg_eq( """Check if a FeatureGroup has the appropriate values set.""" assert fg_to_check.name == name assert fg_to_check.resource_name == resource_name - assert fg_to_check.source == fs_utils.FeatureGroupBigQuerySource( + assert fg_to_check.source == FeatureGroupBigQuerySource( uri=source_uri, entity_id_columns=entity_id_columns, ) @@ -101,3 +133,103 @@ def test_init(feature_group_name, get_fg_mock): location=_TEST_LOCATION, labels=_TEST_FG1_LABELS, ) + + +def test_create_fg_no_source_raises_error(): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + with pytest.raises( + ValueError, + match=re.escape("Please specify a valid source."), + ): + FeatureGroup.create("fg") + + +def test_create_fg_bad_source_raises_error(): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + with pytest.raises( + ValueError, + match=re.escape("Only FeatureGroupBigQuerySource is a supported source."), + ): + FeatureGroup.create("fg", source=int(1)) + + +def test_create_fg_no_source_bq_uri_raises_error(): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + with pytest.raises( + ValueError, + match=re.escape("Please specify URI in BigQuery source."), + ): + FeatureGroup.create( + "fg", source=FeatureGroupBigQuerySource(uri=None, entity_id_columns=None) + ) + + +@pytest.mark.parametrize("create_request_timeout", [None, 1.0]) +@pytest.mark.parametrize("sync", [True, False]) +def test_create_fg( + create_fg_mock, get_fg_mock, fg_logger_mock, 
create_request_timeout, sync +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fg = FeatureGroup.create( + _TEST_FG1_ID, + source=FeatureGroupBigQuerySource( + uri=_TEST_FG1_BQ_URI, + entity_id_columns=_TEST_FG1_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FG1_LABELS, + create_request_timeout=create_request_timeout, + sync=sync, + ) + + if not sync: + fg.wait() + + # When creating, the FeatureOnlineStore object doesn't have the path set. + expected_fg = types.feature_group.FeatureGroup( + name=_TEST_FG1_ID, + big_query=types.feature_group.FeatureGroup.BigQuery( + big_query_source=types.io.BigQuerySource( + input_uri=_TEST_FG1_BQ_URI, + ), + entity_id_columns=_TEST_FG1_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FG1_LABELS, + ) + create_fg_mock.assert_called_once_with( + parent=_TEST_PARENT, + feature_group=expected_fg, + feature_group_id=_TEST_FG1_ID, + metadata=(), + timeout=create_request_timeout, + ) + + fg_logger_mock.assert_has_calls( + [ + call("Creating FeatureGroup"), + call( + f"Create FeatureGroup backing LRO: {create_fg_mock.return_value.operation.name}" + ), + call( + "FeatureGroup created. 
Resource name: projects/test-project/locations/us-central1/featureGroups/my_fg1" + ), + call("To use this FeatureGroup in another session:"), + call( + "feature_group = aiplatform.FeatureGroup('projects/test-project/locations/us-central1/featureGroups/my_fg1')" + ), + ] + ) + + fg_eq( + fg, + name=_TEST_FG1_ID, + resource_name=_TEST_FG1_PATH, + source_uri=_TEST_FG1_BQ_URI, + entity_id_columns=_TEST_FG1_ENTITY_ID_COLUMNS, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FG1_LABELS, + ) diff --git a/vertexai/resources/preview/__init__.py b/vertexai/resources/preview/__init__.py index aef94c6099..e0ea4632c8 100644 --- a/vertexai/resources/preview/__init__.py +++ b/vertexai/resources/preview/__init__.py @@ -64,6 +64,7 @@ "EntityType", "PipelineJobSchedule", "FeatureGroup", + "FeatureGroupBigQuerySource", "FeatureOnlineStoreType", "FeatureOnlineStore", "FeatureView", diff --git a/vertexai/resources/preview/feature_store/__init__.py b/vertexai/resources/preview/feature_store/__init__.py index 07855d5088..bc7d1b0373 100644 --- a/vertexai/resources/preview/feature_store/__init__.py +++ b/vertexai/resources/preview/feature_store/__init__.py @@ -30,6 +30,7 @@ ) from vertexai.resources.preview.feature_store.utils import ( + FeatureGroupBigQuerySource, FeatureViewBigQuerySource, FeatureViewReadResponse, IndexConfig, @@ -41,6 +42,7 @@ __all__ = ( FeatureGroup, + FeatureGroupBigQuerySource, FeatureOnlineStoreType, FeatureOnlineStore, FeatureView, diff --git a/vertexai/resources/preview/feature_store/feature_group.py b/vertexai/resources/preview/feature_store/feature_group.py index be437389cd..90198a4e01 100644 --- a/vertexai/resources/preview/feature_store/feature_group.py +++ b/vertexai/resources/preview/feature_store/feature_group.py @@ -15,14 +15,26 @@ # limitations under the License. 
# -from typing import Optional +from typing import ( + Sequence, + Tuple, + Dict, + List, + Optional, +) from google.auth import credentials as auth_credentials -from google.cloud.aiplatform import base +from google.cloud.aiplatform import base, initializer from google.cloud.aiplatform import utils from google.cloud.aiplatform.compat.types import ( feature_group as gca_feature_group, + io as gca_io, +) +from vertexai.resources.preview.feature_store.utils import ( + FeatureGroupBigQuerySource, ) -import vertexai.resources.preview.feature_store.utils as fs_utils + + +_LOGGER = base.Logger(__name__) class FeatureGroup(base.VertexAiResourceNounWithFutureManager): @@ -71,9 +83,134 @@ def __init__( self._gca_resource = self._get_gca_resource(resource_name=name) + @classmethod + def create( + cls, + name: str, + source: FeatureGroupBigQuerySource = None, + entity_id_columns: Optional[List[str]] = None, + labels: Optional[Dict[str, str]] = None, + description: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = None, + create_request_timeout: Optional[float] = None, + sync: bool = True, + ) -> "FeatureGroup": + """Creates a new feature group. + + Args: + name: The name of the feature group. + source: The BigQuery source of the feature group. + entity_id_columns: + The entity ID columns. If not specified, defaults to + ['entity_id']. + labels: + The labels with user-defined metadata to organize your + FeatureGroup. + + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + + See https://goo.gl/xmQnxf for more information + on and examples of labels. No more than 64 user + labels can be associated with one + FeatureGroup(System labels are excluded)." 
+ System reserved label keys are prefixed with + "aiplatform.googleapis.com/" and are immutable. + description: Description of the FeatureGroup. + project: + Project to create feature group in. If unset, the project set in + aiplatform.init will be used. + location: + Location to create feature group in. If not set, location set in + aiplatform.init will be used. + credentials: + Custom credentials to use to create this feature group. + Overrides credentials set in aiplatform.init. + request_metadata: + Strings which should be sent along with the request as metadata. + create_request_timeout: + The timeout for the create request in seconds. + sync: + Whether to execute this creation synchronously. If False, this + method will be executed in concurrent Future and any downstream + object will be immediately returned and synced when the Future + has completed. + + Returns: + FeatureGroup - the FeatureGroup resource object. + """ + + if not source: + raise ValueError("Please specify a valid source.") + + # Only BigQuery source is supported right now. + if not isinstance(source, FeatureGroupBigQuerySource): + raise ValueError("Only FeatureGroupBigQuerySource is a supported source.") + + # BigQuery source validation. + if not source.uri: + raise ValueError("Please specify URI in BigQuery source.") + + if not source.entity_id_columns: + _LOGGER.info( + "No entity ID columns specified in BigQuery source. Defaulting to ['entity_id']." 
+ ) + entity_id_columns = ["entity_id"] + else: + entity_id_columns = source.entity_id_columns + + gapic_feature_group = gca_feature_group.FeatureGroup( + big_query=gca_feature_group.FeatureGroup.BigQuery( + big_query_source=gca_io.BigQuerySource(input_uri=source.uri), + entity_id_columns=entity_id_columns, + ), + name=name, + description=description, + ) + + if labels: + utils.validate_labels(labels) + gapic_feature_group.labels = labels + + if request_metadata is None: + request_metadata = () + + api_client = cls._instantiate_client(location=location, credentials=credentials) + + create_feature_group_lro = api_client.create_feature_group( + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + feature_group=gapic_feature_group, + feature_group_id=name, + metadata=request_metadata, + timeout=create_request_timeout, + ) + + _LOGGER.log_create_with_lro(cls, create_feature_group_lro) + + created_feature_group = create_feature_group_lro.result() + + _LOGGER.log_create_complete(cls, created_feature_group, "feature_group") + + feature_group_obj = cls( + name=created_feature_group.name, + project=project, + location=location, + credentials=credentials, + ) + + return feature_group_obj + @property - def source(self) -> fs_utils.FeatureGroupBigQuerySource: - return fs_utils.FeatureGroupBigQuerySource( + def source(self) -> FeatureGroupBigQuerySource: + return FeatureGroupBigQuerySource( uri=self._gca_resource.big_query.big_query_source.input_uri, entity_id_columns=self._gca_resource.big_query.entity_id_columns, ) From 7fea7547084277dc974cbacc517ca1e95629a034 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Wed, 8 May 2024 12:27:08 -0700 Subject: [PATCH 17/30] feat: LLM - Support tuning of new text embedding models by migrating to the new v1.1.3 pipeline. 
PiperOrigin-RevId: 631887159 --- tests/unit/aiplatform/test_language_models.py | 117 ++++++++++++++---- .../_model_garden/_model_garden_models.py | 6 +- vertexai/language_models/_language_models.py | 39 +++--- 3 files changed, 115 insertions(+), 47 deletions(-) diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py index 067e1e8318..b6f63f3490 100644 --- a/tests/unit/aiplatform/test_language_models.py +++ b/tests/unit/aiplatform/test_language_models.py @@ -563,7 +563,7 @@ def reverse_string_2(s):""", "parameterType": "STRING", }, "base_model_version_id": { - "defaultValue": "textembedding-gecko@001", + "defaultValue": "text-embedding-004", "description": "which base model to tune. This may be any stable\nnumbered version, for example `textembedding-gecko@001`.", "isOptional": True, "parameterType": "STRING", @@ -578,17 +578,15 @@ def reverse_string_2(s):""", "description": "the GCS path to the corpus data location.", "parameterType": "STRING", }, - "iterations": { - "defaultValue": 1000, - "description": "the number of steps to perform fine-tuning.", + "encryption_spec_key_name": { + "defaultValue": "", "isOptional": True, - "parameterType": "NUMBER_INTEGER", + "parameterType": "STRING", }, - "location": { - "defaultValue": "us-central1", - "description": "GCP region to run the pipeline.", + "learning_rate_multiplier": { + "defaultValue": 1.0, "isOptional": True, - "parameterType": "STRING", + "parameterType": "NUMBER_DOUBLE", }, "machine_type": { "defaultValue": "n1-standard-16", @@ -602,9 +600,10 @@ def reverse_string_2(s):""", "isOptional": True, "parameterType": "STRING", }, - "project": { - "description": "user's project id.", - "parameterType": "STRING", + "output_dimensionality": { + "defaultValue": -1, + "isOptional": True, + "parameterType": "NUMBER_INTEGER", }, "queries_path": { "description": "the GCS path to the queries location.", @@ -626,6 +625,12 @@ def reverse_string_2(s):""", "description": "the 
GCS path to the train label data location.", "parameterType": "STRING", }, + "train_steps": { + "defaultValue": 1000, + "description": "the number of steps to perform fine-tuning.", + "isOptional": True, + "parameterType": "NUMBER_INTEGER", + }, "validation_label_path": { "defaultValue": "", "description": "The GCS path to the validation label data location.", @@ -2283,6 +2288,61 @@ def test_text_generation_response_repr(self): ["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"], indirect=True, ) + @pytest.mark.parametrize( + "base_model_version_id,tune_args,expected_pipeline_args", + [ # Do not pass any optional parameters. + ( + "textembedding-gecko@003", + dict( + training_data="gs://bucket/training.tsv", + corpus_data="gs://bucket/corpus.jsonl", + queries_data="gs://bucket/queries.jsonl", + ), + dict( + base_model_version_id="textembedding-gecko@003", + train_label_path="gs://bucket/training.tsv", + corpus_path="gs://bucket/corpus.jsonl", + queries_path="gs://bucket/queries.jsonl", + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ), + ), + # Pass all optional parameters. 
+ ( + "text-multilingual-embedding-002", + dict( + training_data="gs://bucket/training.tsv", + corpus_data="gs://bucket/corpus.jsonl", + queries_data="gs://bucket/queries.jsonl", + test_data="gs://bucket/test.tsv", + validation_data="gs://bucket/validation.tsv", + tuned_model_location="us-central1", + model_display_name="my-tuned-model", + train_steps=30, + batch_size=256, + accelerator="NVIDIA_TESLA_V100", + accelerator_count=1, + machine_type="n1-highmem-16", + task_type="DEFAULT", + ), + dict( + train_steps=30, + accelerator_type="NVIDIA_TESLA_V100", + accelerator_count=1, + machine_type="n1-highmem-16", + base_model_version_id="text-multilingual-embedding-002", + train_label_path="gs://bucket/training.tsv", + corpus_path="gs://bucket/corpus.jsonl", + queries_path="gs://bucket/queries.jsonl", + test_label_path="gs://bucket/test.tsv", + batch_size=256, + model_display_name="my-tuned-model", + validation_label_path="gs://bucket/validation.tsv", + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + task_type="DEFAULT", + ), + ), + ], + ) def test_tune_text_embedding_model( self, mock_pipeline_service_create, @@ -2294,6 +2354,9 @@ def test_tune_text_embedding_model( mock_gcs_upload, mock_request_urlopen_gecko, mock_deploy_tuned_embedding_model, + tune_args, + expected_pipeline_args, + base_model_version_id, ): """Tests tuning the text embedding model.""" aiplatform.init( @@ -2309,23 +2372,23 @@ def test_tune_text_embedding_model( ), ): model = language_models.TextEmbeddingModel.from_pretrained( - "textembedding-gecko@003" - ) - tuning_job = model.tune_model( - training_data="gs://bucket/training.tsv", - corpus_data="gs://bucket/corpus.jsonl", - queries_data="gs://bucket/queries.jsonl", - test_data="gs://bucket/test.tsv", - tuned_model_location="us-central1", - train_steps=10, - accelerator="NVIDIA_TESLA_A100", + base_model_version_id ) + tuning_job = model.tune_model(**tune_args) call_kwargs = mock_pipeline_service_create.call_args[1] - pipeline_arguments = 
call_kwargs[ - "pipeline_job" - ].runtime_config.parameter_values - assert pipeline_arguments["iterations"] == 10 - assert pipeline_arguments["accelerator_type"] == "NVIDIA_TESLA_A100" + pipeline_arguments = dict( + call_kwargs["pipeline_job"].runtime_config.parameter_values + ) + + if ( + "model_display_name" not in tune_args + and "model_display_name" in pipeline_arguments + ): + # This is automatically generated from some params, so don't + # check it. + del pipeline_arguments["model_display_name"] + + assert pipeline_arguments == expected_pipeline_args # Testing the tuned model tuned_model = tuning_job.deploy_tuned_model() diff --git a/vertexai/_model_garden/_model_garden_models.py b/vertexai/_model_garden/_model_garden_models.py index e94d8a6e58..febde9d25d 100644 --- a/vertexai/_model_garden/_model_garden_models.py +++ b/vertexai/_model_garden/_model_garden_models.py @@ -39,8 +39,10 @@ "chat-bison-32k": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0", "codechat-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0", "codechat-bison-32k": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0", - "textembedding-gecko": "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.2", - "textembedding-gecko-multilingual": "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.2", + "textembedding-gecko": "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.3", + "textembedding-gecko-multilingual": "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.3", + "text-embedding-004": "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.3", + "text-multilingual-embedding-002": "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.3", } _LOGGER = 
base.Logger(__name__) diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index 2e0aeb8fe6..c78d93cd49 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -414,20 +414,24 @@ def _tune_model( model_id=self._model_id, schema_to_class_map={self._INSTANCE_SCHEMA_URI: type(self)}, ) - if model_info.tuning_pipeline_uri.startswith( - "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model" - ): - train_steps = tuning_parameters.pop("train_steps", None) - if train_steps: - tuning_parameters["iterations"] = train_steps + if _is_text_embedding_tuning_pipeline(model_info.tuning_pipeline_uri): tunable_base_model_id = self._model_id.rpartition("/")[-1] tuning_parameters["base_model_version_id"] = tunable_base_model_id else: tuning_parameters["large_model_reference"] = model_info.tuning_model_id - if aiplatform_initializer.global_config.encryption_spec_key_name: - tuning_parameters[ - "encryption_spec_key_name" - ] = aiplatform_initializer.global_config.encryption_spec_key_name + tuning_parameters.update( + { + "project": aiplatform_initializer.global_config.project, + # TODO(b/275444096): Remove the explicit location once tuning + # can happen in all regions. 
+ # "location": aiplatform_initializer.global_config.location, + "location": tuned_model_location, + } + ) + if aiplatform_initializer.global_config.encryption_spec_key_name: + tuning_parameters[ + "encryption_spec_key_name" + ] = aiplatform_initializer.global_config.encryption_spec_key_name if not model_info.tuning_pipeline_uri: raise RuntimeError(f"The {self._model_id} model does not support tuning") @@ -3890,6 +3894,12 @@ def _maybe_upload_training_data( ) +def _is_text_embedding_tuning_pipeline(pipeline_uri: str) -> bool: + return pipeline_uri.startswith( + "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model" + ) + + def _launch_tuning_job( training_data: Union[str, "pandas.core.frame.DataFrame"], model_id: str, @@ -3931,16 +3941,9 @@ def _launch_tuning_job( model_display_name = name[:max_display_name_length] pipeline_arguments = { - "project": aiplatform_initializer.global_config.project, - # TODO(b/275444096): Remove the explicit location once tuning can happen in all regions - # "location": aiplatform_initializer.global_config.location, - "location": tuned_model_location, "model_display_name": model_display_name, } - - if tuning_pipeline_uri.startswith( - "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model" - ): + if _is_text_embedding_tuning_pipeline(tuning_pipeline_uri): pipeline_arguments["train_label_path"] = training_data_path elif training_data_path.startswith("gs://"): pipeline_arguments["dataset_uri"] = training_data_path From 378c68a5bcde3d012c7f725d453e73cab70cc9ff Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Wed, 8 May 2024 14:35:09 -0700 Subject: [PATCH 18/30] chore: add FeatureGroup list tests PiperOrigin-RevId: 631927820 --- .../unit/vertexai/feature_store_constants.py | 36 +++++++++++ tests/unit/vertexai/test_feature_group.py | 60 +++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/tests/unit/vertexai/feature_store_constants.py 
b/tests/unit/vertexai/feature_store_constants.py index 0c04dbcdc5..b6fd378bb6 100644 --- a/tests/unit/vertexai/feature_store_constants.py +++ b/tests/unit/vertexai/feature_store_constants.py @@ -278,3 +278,39 @@ ), labels=_TEST_FG1_LABELS, ) + + +_TEST_FG2_ID = "my_fg2" +_TEST_FG2_PATH = f"{_TEST_PARENT}/featureGroups/{_TEST_FG2_ID}" +_TEST_FG2_BQ_URI = f"bq://{_TEST_PROJECT}.my_dataset.my_table_for_fg2" +_TEST_FG2_ENTITY_ID_COLUMNS = ["entity_id1", "entity_id2"] +_TEST_FG2_LABELS = {"my_key2": "my_fg2"} +_TEST_FG2 = types.feature_group.FeatureGroup( + name=_TEST_FG2_PATH, + big_query=types.feature_group.FeatureGroup.BigQuery( + big_query_source=types.io.BigQuerySource( + input_uri=_TEST_FG2_BQ_URI, + ), + entity_id_columns=_TEST_FG2_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FG2_LABELS, +) + + +_TEST_FG3_ID = "my_fg3" +_TEST_FG3_PATH = f"{_TEST_PARENT}/featureGroups/{_TEST_FG3_ID}" +_TEST_FG3_BQ_URI = f"bq://{_TEST_PROJECT}.my_dataset.my_table_for_fg3" +_TEST_FG3_ENTITY_ID_COLUMNS = ["entity_id1", "entity_id2", "entity_id3"] +_TEST_FG3_LABELS = {"my_key3": "my_fg3"} +_TEST_FG3 = types.feature_group.FeatureGroup( + name=_TEST_FG3_PATH, + big_query=types.feature_group.FeatureGroup.BigQuery( + big_query_source=types.io.BigQuerySource( + input_uri=_TEST_FG3_BQ_URI, + ), + entity_id_columns=_TEST_FG3_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FG3_LABELS, +) + +_TEST_FG_LIST = [_TEST_FG1, _TEST_FG2, _TEST_FG3] diff --git a/tests/unit/vertexai/test_feature_group.py b/tests/unit/vertexai/test_feature_group.py index ba986c3ef1..477d41e0d3 100644 --- a/tests/unit/vertexai/test_feature_group.py +++ b/tests/unit/vertexai/test_feature_group.py @@ -49,6 +49,17 @@ _TEST_FG1_BQ_URI, _TEST_FG1_ENTITY_ID_COLUMNS, _TEST_FG1_LABELS, + _TEST_FG2_ID, + _TEST_FG2_PATH, + _TEST_FG2_BQ_URI, + _TEST_FG2_ENTITY_ID_COLUMNS, + _TEST_FG2_LABELS, + _TEST_FG3_ID, + _TEST_FG3_PATH, + _TEST_FG3_BQ_URI, + _TEST_FG3_ENTITY_ID_COLUMNS, + _TEST_FG3_LABELS, + _TEST_FG_LIST, ) @@ -87,6 +98,16 @@ def 
create_fg_mock(): yield create_fg_mock +@pytest.fixture +def list_fg_mock(): + with patch.object( + feature_registry_service_client.FeatureRegistryServiceClient, + "list_feature_groups", + ) as list_fg_mock: + list_fg_mock.return_value = _TEST_FG_LIST + yield list_fg_mock + + def fg_eq( fg_to_check: FeatureGroup, name: str, @@ -233,3 +254,42 @@ def test_create_fg( location=_TEST_LOCATION, labels=_TEST_FG1_LABELS, ) + + +def test_list(list_fg_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + feature_groups = FeatureGroup.list() + + list_fg_mock.assert_called_once_with(request={"parent": _TEST_PARENT}) + assert len(feature_groups) == len(_TEST_FG_LIST) + fg_eq( + feature_groups[0], + name=_TEST_FG1_ID, + resource_name=_TEST_FG1_PATH, + source_uri=_TEST_FG1_BQ_URI, + entity_id_columns=_TEST_FG1_ENTITY_ID_COLUMNS, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FG1_LABELS, + ) + fg_eq( + feature_groups[1], + name=_TEST_FG2_ID, + resource_name=_TEST_FG2_PATH, + source_uri=_TEST_FG2_BQ_URI, + entity_id_columns=_TEST_FG2_ENTITY_ID_COLUMNS, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FG2_LABELS, + ) + fg_eq( + feature_groups[2], + name=_TEST_FG3_ID, + resource_name=_TEST_FG3_PATH, + source_uri=_TEST_FG3_BQ_URI, + entity_id_columns=_TEST_FG3_ENTITY_ID_COLUMNS, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FG3_LABELS, + ) From 6150322dde4aadf93d9e03a6c617d679f5791707 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Wed, 8 May 2024 16:17:26 -0700 Subject: [PATCH 19/30] chore: Add batch_dedicated_resources parameter for feature attribution spec in Model Monitoring SDK. 
PiperOrigin-RevId: 631957771 --- .../preview/ml_monitoring/spec/objective.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/vertexai/resources/preview/ml_monitoring/spec/objective.py b/vertexai/resources/preview/ml_monitoring/spec/objective.py index 2546cfe074..d36456de85 100644 --- a/vertexai/resources/preview/ml_monitoring/spec/objective.py +++ b/vertexai/resources/preview/ml_monitoring/spec/objective.py @@ -19,6 +19,7 @@ from google.cloud.aiplatform.compat.types import ( explanation_v1beta1 as explanation, + machine_resources_v1beta1 as machine_resources, model_monitoring_alert_v1beta1 as model_monitoring_alert, model_monitoring_spec_v1beta1 as model_monitoring_spec, ) @@ -156,6 +157,11 @@ class FeatureAttributionSpec: features=["feature1"] default_alert_threshold=0.01, feature_alert_thresholds={"feature1":0.02, "feature2":0.01}, + batch_dedicated_resources=BatchDedicatedResources( + starting_replica_count=1, + max_replica_count=2, + machine_spec=my_machine_spec, + ), ) Attributes: @@ -170,6 +176,10 @@ class FeatureAttributionSpec: feature_alert_thresholds (Dict[str, float]): Optional. Per feature alert threshold will override default alert threshold. + batch_dedicated_resources (machine_resources.BatchDedicatedResources): + Optional. The config of resources used by the Model Monitoring during + the batch explanation for non-AutoML models. If not set, `n1-standard-2` + machine type will be used by default. 
""" def __init__( @@ -177,10 +187,12 @@ def __init__( features: Optional[List[str]] = None, default_alert_threshold: Optional[float] = None, feature_alert_thresholds: Optional[Dict[str, float]] = None, + batch_dedicated_resources: Optional[machine_resources.BatchDedicatedResources] = None, ): self.features = features self.default_alert_threshold = default_alert_threshold self.feature_alert_thresholds = feature_alert_thresholds + self.batch_dedicated_resources = batch_dedicated_resources def _as_proto( self, @@ -216,6 +228,7 @@ def _as_proto( default_alert_condition=user_default_alert_threshold, feature_alert_conditions=user_alert_thresholds, features=user_features, + batch_explanation_dedicated_resources=self.batch_dedicated_resources, ) ) From cc8bc965932efb68a30db9decb5a24cf597b0d8b Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Wed, 8 May 2024 17:29:46 -0700 Subject: [PATCH 20/30] feat: LLM - Text embedding - Added the `output_dimensionality` and `learning_rate_multiplier` parameters to text embedding tuning (Preview only) PiperOrigin-RevId: 631976561 --- tests/unit/aiplatform/test_language_models.py | 16 ++- vertexai/language_models/_language_models.py | 124 ++++++++++++++++-- 2 files changed, 128 insertions(+), 12 deletions(-) diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py index b6f63f3490..4dc40b1146 100644 --- a/tests/unit/aiplatform/test_language_models.py +++ b/tests/unit/aiplatform/test_language_models.py @@ -1661,7 +1661,7 @@ def get_endpoint_mock(): @pytest.fixture def mock_deploy_tuned_embedding_model(get_endpoint_mock): with mock.patch.object( - _language_models._TunableTextEmbeddingModelMixin, "deploy_tuned_model" + _language_models._PreviewTunableTextEmbeddingModelMixin, "deploy_tuned_model" ) as mock_text_generation_model: mock_text_generation_model.return_value._model_id = ( test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME @@ -2289,10 +2289,11 @@ def 
test_text_generation_response_repr(self): indirect=True, ) @pytest.mark.parametrize( - "base_model_version_id,tune_args,expected_pipeline_args", + "base_model_version_id,use_preview_module,tune_args,expected_pipeline_args", [ # Do not pass any optional parameters. ( "textembedding-gecko@003", + False, dict( training_data="gs://bucket/training.tsv", corpus_data="gs://bucket/corpus.jsonl", @@ -2309,6 +2310,7 @@ def test_text_generation_response_repr(self): # Pass all optional parameters. ( "text-multilingual-embedding-002", + True, dict( training_data="gs://bucket/training.tsv", corpus_data="gs://bucket/corpus.jsonl", @@ -2323,6 +2325,8 @@ def test_text_generation_response_repr(self): accelerator_count=1, machine_type="n1-highmem-16", task_type="DEFAULT", + output_dimensionality=128, + learning_rate_multiplier=0.1, ), dict( train_steps=30, @@ -2339,6 +2343,8 @@ def test_text_generation_response_repr(self): validation_label_path="gs://bucket/validation.tsv", encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, task_type="DEFAULT", + output_dimensionality=128, + learning_rate_multiplier=0.1, ), ), ], @@ -2357,6 +2363,7 @@ def test_tune_text_embedding_model( tune_args, expected_pipeline_args, base_model_version_id, + use_preview_module, ): """Tests tuning the text embedding model.""" aiplatform.init( @@ -2371,7 +2378,10 @@ def test_tune_text_embedding_model( _TEXT_GECKO_PUBLISHER_MODEL_DICT ), ): - model = language_models.TextEmbeddingModel.from_pretrained( + language_models_module = ( + preview_language_models if use_preview_module else language_models + ) + model = language_models_module.TextEmbeddingModel.from_pretrained( base_model_version_id ) tuning_job = model.tune_model(**tune_args) diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index c78d93cd49..2080503215 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -239,6 +239,7 @@ def tune_model( 
accelerator_count: Optional[int] = None, accelerator_type: Optional[_ACCELERATOR_TYPE_TYPE] = None, max_context_length: Optional[str] = None, + output_dimensionality: Optional[int] = None, ) -> "_LanguageModelTuningJob": """Tunes a model based on training data. @@ -273,6 +274,8 @@ def tune_model( accelerator_type: Type of accelerator to use. Type can be "TPU" or "GPU". Type is ignored, if accelerator is specified. max_context_length: The max context length used for tuning. Can be either '8k' or '32k' + output_dimensionality: The output dimensionality of the tuned model, + for text embedding tuning. Returns: A `LanguageModelTuningJob` object that represents the tuning job. @@ -293,6 +296,8 @@ def tune_model( tuning_parameters["batch_size"] = batch_size if train_steps is not None: tuning_parameters["train_steps"] = train_steps + if output_dimensionality is not None: + tuning_parameters["output_dimensionality"] = output_dimensionality if learning_rate is not None: _LOGGER.warning( "The learning_rate parameter is deprecated." @@ -2189,7 +2194,7 @@ async def get_embeddings_async( # for corpus, queries, test and validation data. # TODO(b/625884109): Validate input args, batch_size >0 and train_steps >30, and # task_type must be 'DEFAULT' or None if _model_id is textembedding-gecko@001. -class _TunableTextEmbeddingModelMixin(_TunableModelMixin): +class _PreviewTunableTextEmbeddingModelMixin(_TunableModelMixin): @classmethod def get_tuned_model(cls, *args, **kwargs): del args, kwargs # Unused. @@ -2213,7 +2218,9 @@ def tune_model( machine_type: Optional[str] = None, accelerator: Optional[str] = None, accelerator_count: Optional[int] = None, - ) -> "_LanguageModelTuningJob": + output_dimensionality: Optional[int] = None, + learning_rate_multiplier: Optional[float] = None, + ) -> "_TextEmbeddingModelTuningJob": """Tunes a model based on training data. This method launches and returns an asynchronous model tuning job. 
@@ -2229,14 +2236,30 @@ def tune_model( queries_data: URI pointing to data in JSON lines format. test_data: URI pointing to data in TSV format. validation_data: URI pointing to data in TSV format. - batch_size: Size of batch. - train_steps: Number of training batches to tune on. + batch_size: The training batch size. + train_steps: The number of steps to perform model tuning. Must + be greater than 30. tuned_model_location: GCP location where the tuned model should be deployed. model_display_name: Custom display name for the tuned model. - task_type: Type of task. Can be "RETRIEVAL_QUERY", "RETRIEVAL_DOCUMENT", "SEMANTIC_SIMILARITY", "CLASSIFICATION", "CLUSTERING", "QUESTION_ANSWERING", or "FACT_VERIFICATION". - machine_type: Machine type. E.g., "a2-highgpu-1g". See also: https://cloud.google.com/vertex-ai/docs/training/configure-compute. - accelerator_count: Count of accelerators. - accelerator: Kind of accelerator. E.g., "NVIDIA_TESLA_A100". See also: https://cloud.google.com/vertex-ai/docs/training/configure-compute. + task_type: The task type expected to be used during inference. + Valid values are `DEFAULT`, `RETRIEVAL_QUERY`, `RETRIEVAL_DOCUMENT`, + `SEMANTIC_SIMILARITY`, `CLASSIFICATION`, `CLUSTERING`, + `FACT_VERIFICATION`, and `QUESTION_ANSWERING`. + machine_type: The machine type to use for training. For information + about selecting the machine type that matches the accelerator + type and count you have selected, see + https://cloud.google.com/compute/docs/gpus. + accelerator: The accelerator type to use for tuning, for example + `NVIDIA_TESLA_V100`. For possible values, see + https://cloud.google.com/vertex-ai/generative-ai/docs/models/tune-embeddings#using-accelerators. + accelerator_count: The number of accelerators to use when training. + Using a greater number of accelerators may make training faster, + but has no effect on quality. + output_dimensionality: The desired embedding dimension of your + tuned model, up to 768. 
This is only supported for models + `text-embedding-004` and `text-multilingual-embedding-002`. + learning_rate_multiplier: A multiplier to apply to the + recommended learning rate during tuning. Returns: A `LanguageModelTuningJob` object that represents the tuning job. Calling `job.result()` blocks until the tuning is complete and returns a `LanguageModel` object. @@ -2260,6 +2283,8 @@ def tune_model( machine_type=machine_type, accelerator=accelerator, accelerator_count=accelerator_count, + output_dimensionality=output_dimensionality, + learning_rate_multiplier=learning_rate_multiplier, ) def _bundle_up_tuning_job(self, pipeline_job): @@ -2318,14 +2343,95 @@ def deploy_tuned_model( return model +class _TunableTextEmbeddingModelMixin(_PreviewTunableTextEmbeddingModelMixin): + def tune_model( + self, + *, + training_data: Optional[str] = None, + corpus_data: Optional[str] = None, + queries_data: Optional[str] = None, + test_data: Optional[str] = None, + validation_data: Optional[str] = None, + batch_size: Optional[int] = None, + train_steps: Optional[int] = None, + tuned_model_location: Optional[str] = None, + model_display_name: Optional[str] = None, + task_type: Optional[str] = None, + machine_type: Optional[str] = None, + accelerator: Optional[str] = None, + accelerator_count: Optional[int] = None, + ) -> "_TextEmbeddingModelTuningJob": + """Tunes a model based on training data. + + This method launches and returns an asynchronous model tuning job. + Usage: + ``` + tuning_job = model.tune_model(...) + ... do some other work + tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete + + Args: + training_data: URI pointing to training data in TSV format. + corpus_data: URI pointing to data in JSON lines format. + queries_data: URI pointing to data in JSON lines format. + test_data: URI pointing to data in TSV format. + validation_data: URI pointing to data in TSV format. + batch_size: The training batch size. 
+ train_steps: The number of steps to perform model tuning. Must + be greater than 30. + tuned_model_location: GCP location where the tuned model should be deployed. + model_display_name: Custom display name for the tuned model. + task_type: The task type expected to be used during inference. + Valid values are `DEFAULT`, `RETRIEVAL_QUERY`, `RETRIEVAL_DOCUMENT`, + `SEMANTIC_SIMILARITY`, `CLASSIFICATION`, `CLUSTERING`, + `FACT_VERIFICATION`, and `QUESTION_ANSWERING`. + machine_type: The machine type to use for training. For information + about selecting the machine type that matches the accelerator + type and count you have selected, see + https://cloud.google.com/compute/docs/gpus. + accelerator: The accelerator type to use for tuning, for example + `NVIDIA_TESLA_V100`. For possible values, see + https://cloud.google.com/vertex-ai/generative-ai/docs/models/tune-embeddings#using-accelerators. + accelerator_count: The number of accelerators to use when training. + Using a greater number of accelerators may make training faster, + but has no effect on quality. + Returns: + A `LanguageModelTuningJob` object that represents the tuning job. + Calling `job.result()` blocks until the tuning is complete and + returns a `LanguageModel` object. + + Raises: + ValueError: If the provided parameter combinations or values are not + supported. 
+ RuntimeError: If the model does not support tuning + """ + + return super().tune_model( + training_data=training_data, + corpus_data=corpus_data, + queries_data=queries_data, + test_data=test_data, + validation_data=validation_data, + task_type=task_type, + batch_size=batch_size, + train_steps=train_steps, + tuned_model_location=tuned_model_location, + model_display_name=model_display_name, + machine_type=machine_type, + accelerator=accelerator, + accelerator_count=accelerator_count, + ) + + class TextEmbeddingModel(_TextEmbeddingModel, _TunableTextEmbeddingModelMixin): __module__ = "vertexai.language_models" class _PreviewTextEmbeddingModel( - TextEmbeddingModel, + _TextEmbeddingModel, _ModelWithBatchPredict, _CountTokensMixin, + _PreviewTunableTextEmbeddingModelMixin, ): __name__ = "TextEmbeddingModel" __module__ = "vertexai.preview.language_models" From e0c6227d0dd92d83c98cc3c7e7607fd252e74a32 Mon Sep 17 00:00:00 2001 From: Amy Wu Date: Wed, 8 May 2024 19:20:50 -0700 Subject: [PATCH 21/30] feat: Support custom service account for Ray cluster creation and Ray Client connection PiperOrigin-RevId: 631998839 --- .../preview/vertex_ray/client_builder.py | 23 ++--- .../preview/vertex_ray/cluster_init.py | 20 ++++- .../preview/vertex_ray/util/_gapic_utils.py | 5 +- .../preview/vertex_ray/util/resources.py | 5 +- tests/unit/vertex_ray/test_cluster_init.py | 62 +++++++++++++ tests/unit/vertex_ray/test_constants.py | 89 +++++++++++++++++-- .../unit/vertex_ray/test_vertex_ray_client.py | 63 +++++++++++++ 7 files changed, 248 insertions(+), 19 deletions(-) diff --git a/google/cloud/aiplatform/preview/vertex_ray/client_builder.py b/google/cloud/aiplatform/preview/vertex_ray/client_builder.py index c76147ab25..ff67a0528b 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/client_builder.py +++ b/google/cloud/aiplatform/preview/vertex_ray/client_builder.py @@ -98,8 +98,21 @@ def __init__(self, address: Optional[str]) -> None: public_address = 
self.response.resource_runtime.access_uris.get( "RAY_CLIENT_ENDPOINT" ) + service_account = ( + self.response.resource_runtime_spec.service_account_spec.service_account + ) + if public_address is None: address = private_address + if service_account: + raise ValueError( + "[Ray on Vertex AI]: Ray Cluster ", + address, + " failed to start Head node properly because custom service" + " account isn't supported in peered VPC network. Use public" + " endpoint instead (createa a cluster withought specifying" + " VPC network).", + ) else: address = public_address @@ -110,17 +123,7 @@ def __init__(self, address: Optional[str]) -> None: persistent_resource_id, " Head node is not reachable. Please ensure that a valid VPC network has been specified.", ) - # Handling service_account - service_account = ( - self.response.resource_runtime_spec.service_account_spec.service_account - ) - if service_account: - raise ValueError( - "[Ray on Vertex AI]: Ray Cluster ", - address, - " failed to start Head node properly because custom service account isn't supported.", - ) logging.debug("[Ray on Vertex AI]: Resolved head node ip: %s", address) cluster = _gapic_utils.persistent_resource_to_cluster( persistent_resource=self.response diff --git a/google/cloud/aiplatform/preview/vertex_ray/cluster_init.py b/google/cloud/aiplatform/preview/vertex_ray/cluster_init.py index 2710512cd5..d0daf72cd7 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/cluster_init.py +++ b/google/cloud/aiplatform/preview/vertex_ray/cluster_init.py @@ -32,6 +32,7 @@ RayMetricSpec, ResourcePool, ResourceRuntimeSpec, + ServiceAccountSpec, ) from google.cloud.aiplatform.preview.vertex_ray.util import ( @@ -48,6 +49,7 @@ def create_ray_cluster( python_version: Optional[str] = "3.10", ray_version: Optional[str] = "2.9", network: Optional[str] = None, + service_account: Optional[str] = None, cluster_name: Optional[str] = None, worker_node_types: Optional[List[resources.Resources]] = None, custom_images: 
Optional[resources.NodeImages] = None, @@ -78,7 +80,9 @@ def create_ray_cluster( cluster_resource_name = vertex_ray.create_ray_cluster( head_node_type=head_node_type, - network="projects/my-project-number/global/networks/my-vpc-name", + network="projects/my-project-number/global/networks/my-vpc-name", # Optional + service_account="my-service-account@my-project-number.iam.gserviceaccount.com", # Optional + cluster_name="my-cluster-name", # Optional worker_node_types=worker_node_types, ray_version="2.9", ) @@ -100,6 +104,8 @@ def create_ray_cluster( Vertex API service. For Ray Job API, VPC network is not required because Ray Cluster connection can be accessed through dashboard address. + service_account: Service account to be used for running Ray programs on + the cluster. cluster_name: This value may be up to 63 characters, and valid characters are `[a-z0-9_-]`. The first character cannot be a number or hyphen. @@ -254,7 +260,17 @@ def create_ray_cluster( ray_spec = RaySpec( resource_pool_images=resource_pool_images, ray_metric_spec=ray_metric_spec ) - resource_runtime_spec = ResourceRuntimeSpec(ray_spec=ray_spec) + if service_account: + service_account_spec = ServiceAccountSpec( + enable_custom_service_account=True, + service_account=service_account, + ) + resource_runtime_spec = ResourceRuntimeSpec( + ray_spec=ray_spec, + service_account_spec=service_account_spec, + ) + else: + resource_runtime_spec = ResourceRuntimeSpec(ray_spec=ray_spec) persistent_resource = PersistentResource( resource_pools=resource_pools, network=network, diff --git a/google/cloud/aiplatform/preview/vertex_ray/util/_gapic_utils.py b/google/cloud/aiplatform/preview/vertex_ray/util/_gapic_utils.py index 4bed9101c7..c1469b1ee0 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/util/_gapic_utils.py +++ b/google/cloud/aiplatform/preview/vertex_ray/util/_gapic_utils.py @@ -166,7 +166,10 @@ def persistent_resource_to_cluster( head_image_uri = ( 
persistent_resource.resource_runtime_spec.ray_spec.resource_pool_images[head_id] ) - + if persistent_resource.resource_runtime_spec.service_account_spec.service_account: + cluster.service_account = ( + persistent_resource.resource_runtime_spec.service_account_spec.service_account + ) if not head_image_uri: head_image_uri = persistent_resource.resource_runtime_spec.ray_spec.image_uri diff --git a/google/cloud/aiplatform/preview/vertex_ray/util/resources.py b/google/cloud/aiplatform/preview/vertex_ray/util/resources.py index 2830ac37b3..5575edbaf7 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/util/resources.py +++ b/google/cloud/aiplatform/preview/vertex_ray/util/resources.py @@ -41,7 +41,7 @@ class Resources: us-docker.pkg.dev/my-project/ray-gpu.2-9.py310-tf:latest). """ - machine_type: Optional[str] = "n1-standard-8" + machine_type: Optional[str] = "n1-standard-16" node_count: Optional[int] = 1 accelerator_type: Optional[str] = None accelerator_count: Optional[int] = 0 @@ -81,6 +81,8 @@ class Cluster: managed in the Vertex API service. For Ray Job API, VPC network is not required because cluster connection can be accessed through dashboard address. + service_account: Service account to be used for running Ray programs on + the cluster. state: Describes the cluster state (defined in PersistentResource.State). python_version: Python version for the ray cluster (e.g. "3.10"). ray_version: Ray version for the ray cluster (e.g. "2.4"). 
@@ -102,6 +104,7 @@ class Cluster: cluster_resource_name: str = None network: str = None + service_account: str = None state: PersistentResource.State = None python_version: str = None ray_version: str = None diff --git a/tests/unit/vertex_ray/test_cluster_init.py b/tests/unit/vertex_ray/test_cluster_init.py index b23d5efa14..d7d51f5a22 100644 --- a/tests/unit/vertex_ray/test_cluster_init.py +++ b/tests/unit/vertex_ray/test_cluster_init.py @@ -92,6 +92,34 @@ def get_persistent_resource_1_pool_custom_image_mock(): yield get_persistent_resource_1_pool_custom_image_mock +@pytest.fixture +def create_persistent_resource_1_pool_byosa_mock(): + with mock.patch.object( + PersistentResourceServiceClient, + "create_persistent_resource", + ) as create_persistent_resource_1_pool_byosa_mock: + create_persistent_resource_lro_mock = mock.Mock(ga_operation.Operation) + create_persistent_resource_lro_mock.result.return_value = ( + tc.ClusterConstants.TEST_RESPONSE_RUNNING_1_POOL_BYOSA + ) + create_persistent_resource_1_pool_byosa_mock.return_value = ( + create_persistent_resource_lro_mock + ) + yield create_persistent_resource_1_pool_byosa_mock + + +@pytest.fixture +def get_persistent_resource_1_pool_byosa_mock(): + with mock.patch.object( + PersistentResourceServiceClient, + "get_persistent_resource", + ) as get_persistent_resource_1_pool_byosa_mock: + get_persistent_resource_1_pool_byosa_mock.return_value = ( + tc.ClusterConstants.TEST_RESPONSE_RUNNING_1_POOL_BYOSA + ) + yield get_persistent_resource_1_pool_byosa_mock + + @pytest.fixture def create_persistent_resource_2_pools_mock(): with mock.patch.object( @@ -426,6 +454,30 @@ def test_create_ray_cluster_initialized_success( ] ) + @pytest.mark.usefixtures("get_persistent_resource_1_pool_byosa_mock") + def test_create_ray_cluster_byosa_success( + self, create_persistent_resource_1_pool_byosa_mock + ): + """If head and worker nodes are duplicate, merge to head pool.""" + cluster_name = vertex_ray.create_ray_cluster( + 
head_node_type=tc.ClusterConstants.TEST_HEAD_NODE_TYPE_1_POOL, + worker_node_types=tc.ClusterConstants.TEST_WORKER_NODE_TYPES_1_POOL, + service_account=tc.ProjectConstants.TEST_SERVICE_ACCOUNT, + cluster_name=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID, + ) + + assert tc.ClusterConstants.TEST_VERTEX_RAY_PR_ADDRESS == cluster_name + + request = persistent_resource_service.CreatePersistentResourceRequest( + parent=tc.ProjectConstants.TEST_PARENT, + persistent_resource=tc.ClusterConstants.TEST_REQUEST_RUNNING_1_POOL_BYOSA, + persistent_resource_id=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID, + ) + + create_persistent_resource_1_pool_byosa_mock.assert_called_with( + request, + ) + def test_create_ray_cluster_head_multinode_error(self): with pytest.raises(ValueError) as e: vertex_ray.create_ray_cluster( @@ -508,6 +560,16 @@ def test_get_ray_cluster_with_custom_image_success( get_persistent_resource_2_pools_custom_image_mock.assert_called_once() cluster_eq(cluster, tc.ClusterConstants.TEST_CLUSTER_CUSTOM_IMAGE) + def test_get_ray_cluster_byosa_success( + self, get_persistent_resource_1_pool_byosa_mock + ): + cluster = vertex_ray.get_ray_cluster( + cluster_resource_name=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ADDRESS + ) + + get_persistent_resource_1_pool_byosa_mock.assert_called_once() + cluster_eq(cluster, tc.ClusterConstants.TEST_CLUSTER_BYOSA) + @pytest.mark.usefixtures("get_persistent_resource_exception_mock") def test_get_ray_cluster_error(self): with pytest.raises(ValueError) as e: diff --git a/tests/unit/vertex_ray/test_constants.py b/tests/unit/vertex_ray/test_constants.py index d33ed14bc1..7154c05e69 100644 --- a/tests/unit/vertex_ray/test_constants.py +++ b/tests/unit/vertex_ray/test_constants.py @@ -16,6 +16,7 @@ # import dataclasses +import sys from google.cloud.aiplatform.preview.vertex_ray.util.resources import Cluster from google.cloud.aiplatform.preview.vertex_ray.util.resources import ( @@ -28,10 +29,10 @@ from 
google.cloud.aiplatform_v1beta1.types.persistent_resource import ( PersistentResource, ) -from google.cloud.aiplatform_v1beta1.types.persistent_resource import RaySpec from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( RayMetricSpec, ) +from google.cloud.aiplatform_v1beta1.types.persistent_resource import RaySpec from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( ResourcePool, ) @@ -41,9 +42,11 @@ from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( ResourceRuntimeSpec, ) - +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( + ServiceAccountSpec, +) import pytest -import sys + rovminversion = pytest.mark.skipif( sys.version_info > (3, 10), reason="Requires python3.10 or lower" @@ -67,6 +70,7 @@ class ProjectConstants: TEST_MODEL_ID = ( f"projects/{TEST_GCP_PROJECT_NUMBER}/locations/{TEST_GCP_REGION}/models/456" ) + TEST_SERVICE_ACCOUNT = "service-account@project.iam.gserviceaccount.com" @dataclasses.dataclass(frozen=True) @@ -79,6 +83,9 @@ class ClusterConstants: TEST_VERTEX_RAY_DASHBOARD_ADDRESS = ( "48b400ad90b8dd3c-dot-us-central1.aiplatform-training.googleusercontent.com" ) + TEST_VERTEX_RAY_CLIENT_ENDPOINT = ( + "88888.us-central1-1234567.staging-ray.vertexai.goog:443" + ) TEST_VERTEX_RAY_PR_ID = "user-persistent-resource-1234567890" TEST_VERTEX_RAY_PR_ADDRESS = ( f"{ProjectConstants.TEST_PARENT}/persistentResources/" + TEST_VERTEX_RAY_PR_ID @@ -106,7 +113,7 @@ class ClusterConstants: TEST_RESOURCE_POOL_0 = ResourcePool( id="head-node", machine_spec=MachineSpec( - machine_type="n1-standard-8", + machine_type="n1-standard-16", accelerator_type="NVIDIA_TESLA_P100", accelerator_count=1, ), @@ -147,6 +154,20 @@ class ClusterConstants: ), network=ProjectConstants.TEST_VPC_NETWORK, ) + TEST_REQUEST_RUNNING_1_POOL_BYOSA = PersistentResource( + resource_pools=[TEST_RESOURCE_POOL_0], + resource_runtime_spec=ResourceRuntimeSpec( + ray_spec=RaySpec( + resource_pool_images={"head-node": 
TEST_GPU_IMAGE}, + ray_metric_spec=RayMetricSpec(disabled=False), + ), + service_account_spec=ServiceAccountSpec( + enable_custom_service_account=True, + service_account=ProjectConstants.TEST_SERVICE_ACCOUNT, + ), + ), + network=None, + ) # Get response has generated name, and URIs TEST_RESPONSE_RUNNING_1_POOL = PersistentResource( name=TEST_VERTEX_RAY_PR_ADDRESS, @@ -185,6 +206,50 @@ class ClusterConstants: ), state="RUNNING", ) + TEST_RESPONSE_RUNNING_1_POOL_BYOSA = PersistentResource( + name=TEST_VERTEX_RAY_PR_ADDRESS, + resource_pools=[TEST_RESOURCE_POOL_0], + resource_runtime_spec=ResourceRuntimeSpec( + ray_spec=RaySpec( + resource_pool_images={"head-node": TEST_GPU_IMAGE}, + ray_metric_spec=RayMetricSpec(disabled=False), + ), + service_account_spec=ServiceAccountSpec( + enable_custom_service_account=True, + service_account=ProjectConstants.TEST_SERVICE_ACCOUNT, + ), + ), + network=None, + resource_runtime=ResourceRuntime( + access_uris={ + "RAY_DASHBOARD_URI": TEST_VERTEX_RAY_DASHBOARD_ADDRESS, + "RAY_CLIENT_ENDPOINT": TEST_VERTEX_RAY_CLIENT_ENDPOINT, + } + ), + state="RUNNING", + ) + TEST_RESPONSE_1_POOL_BYOSA_PRIVATE = PersistentResource( + name=TEST_VERTEX_RAY_PR_ADDRESS, + resource_pools=[TEST_RESOURCE_POOL_0], + resource_runtime_spec=ResourceRuntimeSpec( + ray_spec=RaySpec( + resource_pool_images={"head-node": TEST_GPU_IMAGE}, + ray_metric_spec=RayMetricSpec(disabled=False), + ), + service_account_spec=ServiceAccountSpec( + enable_custom_service_account=True, + service_account=ProjectConstants.TEST_SERVICE_ACCOUNT, + ), + ), + network=ProjectConstants.TEST_VPC_NETWORK, + resource_runtime=ResourceRuntime( + access_uris={ + "RAY_DASHBOARD_URI": TEST_VERTEX_RAY_DASHBOARD_ADDRESS, + "RAY_CLIENT_ENDPOINT": TEST_VERTEX_RAY_CLIENT_ENDPOINT, + } + ), + state="RUNNING", + ) # 2_POOL: worker_node_types and head_node_type have different MachineSpecs TEST_HEAD_NODE_TYPE_2_POOLS = Resources() TEST_WORKER_NODE_TYPES_2_POOLS = [ @@ -208,7 +273,7 @@ class 
ClusterConstants: TEST_RESOURCE_POOL_1 = ResourcePool( id="head-node", machine_spec=MachineSpec( - machine_type="n1-standard-8", + machine_type="n1-standard-16", ), disk_spec=DiskSpec( boot_disk_type="pd-ssd", @@ -302,6 +367,7 @@ class ClusterConstants: python_version="3.10", ray_version="2.9", network=ProjectConstants.TEST_VPC_NETWORK, + service_account=None, state="RUNNING", head_node_type=TEST_HEAD_NODE_TYPE_1_POOL, worker_node_types=TEST_WORKER_NODE_TYPES_1_POOL, @@ -312,6 +378,7 @@ class ClusterConstants: python_version="3.10", ray_version="2.9", network=ProjectConstants.TEST_VPC_NETWORK, + service_account=None, state="RUNNING", head_node_type=TEST_HEAD_NODE_TYPE_2_POOLS, worker_node_types=TEST_WORKER_NODE_TYPES_2_POOLS, @@ -320,11 +387,23 @@ class ClusterConstants: TEST_CLUSTER_CUSTOM_IMAGE = Cluster( cluster_resource_name=TEST_VERTEX_RAY_PR_ADDRESS, network=ProjectConstants.TEST_VPC_NETWORK, + service_account=None, state="RUNNING", head_node_type=TEST_HEAD_NODE_TYPE_2_POOLS_CUSTOM_IMAGE, worker_node_types=TEST_WORKER_NODE_TYPES_2_POOLS_CUSTOM_IMAGE, dashboard_address=TEST_VERTEX_RAY_DASHBOARD_ADDRESS, ) + TEST_CLUSTER_BYOSA = Cluster( + cluster_resource_name=TEST_VERTEX_RAY_PR_ADDRESS, + python_version="3.10", + ray_version="2.9", + network="", + service_account=ProjectConstants.TEST_SERVICE_ACCOUNT, + state="RUNNING", + head_node_type=TEST_HEAD_NODE_TYPE_1_POOL, + worker_node_types=TEST_WORKER_NODE_TYPES_1_POOL, + dashboard_address=TEST_VERTEX_RAY_DASHBOARD_ADDRESS, + ) TEST_BEARER_TOKEN = "test-bearer-token" TEST_HEADERS = { "Content-Type": "application/json", diff --git a/tests/unit/vertex_ray/test_vertex_ray_client.py b/tests/unit/vertex_ray/test_vertex_ray_client.py index 8150020464..26b56ee6a4 100644 --- a/tests/unit/vertex_ray/test_vertex_ray_client.py +++ b/tests/unit/vertex_ray/test_vertex_ray_client.py @@ -43,6 +43,17 @@ ray_client_context=_TEST_CLIENT_CONTEXT, ) +_TEST_VERTEX_RAY_CLIENT_CONTEXT_PUBLIC = ( + 
vertex_ray.client_builder._VertexRayClientContext( + persistent_resource_id="MOCK_PERSISTENT_RESOURCE_ID", + ray_head_uris={ + "RAY_DASHBOARD_URI": tc.ClusterConstants.TEST_VERTEX_RAY_DASHBOARD_ADDRESS, + "RAY_CLIENT_ENDPOINT": tc.ClusterConstants.TEST_VERTEX_RAY_CLIENT_ENDPOINT, + }, + ray_client_context=_TEST_CLIENT_CONTEXT, + ) +) + @pytest.fixture def ray_client_init_mock(): @@ -76,6 +87,26 @@ def get_persistent_resource_status_running_no_ray_mock(): yield resolve_head_ip +@pytest.fixture +def get_persistent_resource_status_running_byosa_public_mock(): + with mock.patch.object( + vertex_ray.util._gapic_utils, "get_persistent_resource" + ) as resolve_head_ip: + resolve_head_ip.return_value = tc.ClusterConstants.TEST_RESPONSE_1_POOL_BYOSA + yield resolve_head_ip + + +@pytest.fixture +def get_persistent_resource_status_running_byosa_private_mock(): + with mock.patch.object( + vertex_ray.util._gapic_utils, "get_persistent_resource" + ) as resolve_head_ip: + resolve_head_ip.return_value = ( + tc.ClusterConstants.TEST_RESPONSE_1_POOL_BYOSA_PRIVATE + ) + yield resolve_head_ip + + class TestClientBuilder: def setup_method(self): importlib.reload(aiplatform.initializer) @@ -143,6 +174,38 @@ def test_connect_running_no_ray(self, ray_client_connect_mock): ray_client_connect_mock.assert_called_once_with() assert str(exception.value) == expected_message + @tc.rovminversion + @pytest.mark.usefixtures("get_persistent_resource_status_running_byosa_public_mock") + def test_connect_running_byosa_public(self, ray_client_connect_mock): + connect_result = vertex_ray.ClientBuilder( + tc.ClusterConstants.TEST_VERTEX_RAY_PR_ADDRESS + ).connect() + ray_client_connect_mock.assert_called_once_with() + assert connect_result == _TEST_VERTEX_RAY_CLIENT_CONTEXT_PUBLIC + assert ( + connect_result.persistent_resource_id + == tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID + ) + + @tc.rovminversion + @pytest.mark.usefixtures( + "get_persistent_resource_status_running_byosa_private_mock" + ) + def 
test_connect_running_byosa_private(self, ray_client_connect_mock): + expected_message = ( + "Ray Cluster ", + tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID, + " failed to start Head node properly because custom service" + " account isn't supported in peered VPC network. ", + ) + with pytest.raises(ValueError) as exception: + vertex_ray.ClientBuilder( + tc.ClusterConstants.TEST_VERTEX_RAY_PR_ADDRESS + ).connect() + + ray_client_connect_mock.assert_called_once_with() + assert str(exception.value) == expected_message + @tc.rovminversion @pytest.mark.parametrize( "address", From 6557d88eb73624c8dbc7da33db129f7cbdae8a06 Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Thu, 9 May 2024 08:10:25 -0700 Subject: [PATCH 22/30] feat: Added the `vision_models.Image._mime_type` property to make `vision_models.Image` compatible with `generative_models.Image` - This will allow `generative_models.Part.from_image` to accept `vision_models.Image` objects. - Added `vision_models.Video._mime_type` - Fixed linter errors. PiperOrigin-RevId: 632153540 --- vertexai/vision_models/_vision_models.py | 75 +++++++++++++++++------- 1 file changed, 55 insertions(+), 20 deletions(-) diff --git a/vertexai/vision_models/_vision_models.py b/vertexai/vision_models/_vision_models.py index daaf3356fc..32c8685a57 100644 --- a/vertexai/vision_models/_vision_models.py +++ b/vertexai/vision_models/_vision_models.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +# pylint: disable=bad-continuation, line-too-long, protected-access """Classes for working with vision models.""" import base64 @@ -99,15 +100,22 @@ def load_from_file(location: str) -> "Image": image = Image(image_bytes=image_bytes) return image + @property + def _blob(self) -> storage.Blob: + if self._gcs_uri is None: + raise AttributeError("_blob is only supported when gcs_uri is set.") + storage_client = storage.Client( + credentials=aiplatform_initializer.global_config.credentials + ) + blob = storage.Blob.from_string(uri=self._gcs_uri, client=storage_client) + # Needed to populate `blob.content_type` + blob.reload() + return blob + @property def _image_bytes(self) -> bytes: if self._loaded_bytes is None: - storage_client = storage.Client( - credentials=aiplatform_initializer.global_config.credentials - ) - self._loaded_bytes = storage.Blob.from_string( - uri=self._gcs_uri, client=storage_client - ).download_as_bytes() + self._loaded_bytes = self._blob.download_as_bytes() return self._loaded_bytes @_image_bytes.setter @@ -117,6 +125,10 @@ def _image_bytes(self, value: bytes): @property def _pil_image(self) -> "PIL_Image.Image": if self._loaded_image is None: + if not PIL_Image: + raise RuntimeError( + "The PIL module is not available. Please install the Pillow package." + ) self._loaded_image = PIL_Image.open(io.BytesIO(self._image_bytes)) return self._loaded_image @@ -124,6 +136,16 @@ def _pil_image(self) -> "PIL_Image.Image": def _size(self): return self._pil_image.size + @property + def _mime_type(self) -> str: + """Returns the MIME type of the image.""" + if self._gcs_uri: + return self._blob.content_type + if PIL_Image: + return PIL_Image.MIME.get(self._pil_image.format, "image/jpeg") + # Fall back to jpeg + return "image/jpeg" + def show(self): """Shows the image. @@ -146,7 +168,7 @@ def _as_base64_string(self) -> str: Returns: Base64 encoding of the image as a string. """ - # ! b64encode returns `bytes` object, not ``str. + # ! 
b64encode returns `bytes` object, not `str`. # We need to convert `bytes` to `str`, otherwise we get service error: # "received initial metadata size exceeds limit" return base64.b64encode(self._image_bytes).decode("ascii") @@ -196,21 +218,36 @@ def load_from_file(location: str) -> "Video": video = Video(video_bytes=video_bytes) return video + @property + def _blob(self) -> storage.Blob: + if self._gcs_uri is None: + raise AttributeError("_blob is only supported when gcs_uri is set.") + storage_client = storage.Client( + credentials=aiplatform_initializer.global_config.credentials + ) + blob = storage.Blob.from_string(uri=self._gcs_uri, client=storage_client) + # Needed to populate `blob.content_type` + blob.reload() + return blob + @property def _video_bytes(self) -> bytes: if self._loaded_bytes is None: - storage_client = storage.Client( - credentials=aiplatform_initializer.global_config.credentials - ) - self._loaded_bytes = storage.Blob.from_string( - uri=self._gcs_uri, client=storage_client - ).download_as_bytes() + self._loaded_bytes = self._blob.download_as_bytes() return self._loaded_bytes @_video_bytes.setter def _video_bytes(self, value: bytes): self._loaded_bytes = value + @property + def _mime_type(self) -> str: + """Returns the MIME type of the video.""" + if self._gcs_uri: + return self._blob.content_type + # Fall back to mp4 + return "video/mp4" + def save(self, location: str): """Saves video to a file. @@ -225,7 +262,7 @@ def _as_base64_string(self) -> str: Returns: Base64 encoding of the video as a string. """ - # ! b64encode returns `bytes` object, not ``str. + # ! b64encode returns `bytes` object, not `str`. 
# We need to convert `bytes` to `str`, otherwise we get service error: # "received initial metadata size exceeds limit" return base64.b64encode(self._video_bytes).decode("ascii") @@ -582,8 +619,7 @@ def generate_images( * "16:9" : 16:9 aspect ratio * "4:3" : 4:3 aspect ratio * "3:4" : 3:4 aspect_ratio - guidance_scale: Controls the strength of the prompt. Suggested values - are: + guidance_scale: Controls the strength of the prompt. Suggested values are: * 0-9 (low strength) * 10-20 (medium strength) * 21+ (high strength) @@ -667,8 +703,7 @@ def edit_image( * 0-9 (low strength) * 10-20 (medium strength) * 21+ (high strength) - edit_mode: Describes the editing mode for the request. Supported values - are: + edit_mode: Describes the editing mode for the request. Supported values are: * inpainting-insert: fills the mask area based on the text prompt (requires mask and text) * inpainting-remove: removes the object(s) in the mask area. @@ -677,7 +712,6 @@ def edit_image( (Requires mask) * product-image: Changes the background for the predominant product or subject in the image - segmentation_classes: List of class IDs for segmentation. Max of 5 IDs mask_mode: Solicits generation of the mask (v/s providing mask as an input). Supported values are: * background: Automatically generates a mask for all regions except @@ -686,6 +720,7 @@ def edit_image( subjects(s) of the image. * semantic: Segment one or more of the segmentation classes using class ID + segmentation_classes: List of class IDs for segmentation. Max of 5 IDs mask_dilation: Defines the dilation percentage of the mask provided. Float between 0 and 1. 
Defaults to 0.03 product_position: Defines whether the product should stay fixed or be @@ -1241,7 +1276,7 @@ class WatermarkVerificationResponse: class WatermarkVerificationModel(_model_garden_models._ModelGardenModel): - """Verifies if an image has a watermark""" + """Verifies if an image has a watermark.""" __module__ = "vertexai.preview.vision_models" From c528b6ff44e2347797336db800ca01240e670d32 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Thu, 9 May 2024 11:04:29 -0700 Subject: [PATCH 23/30] fix: a bug in the evaluation library where the job crashes if only custom metrics are specified. PiperOrigin-RevId: 632207470 --- vertexai/preview/evaluation/_evaluation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vertexai/preview/evaluation/_evaluation.py b/vertexai/preview/evaluation/_evaluation.py index c0e5464f8d..d396f9460f 100644 --- a/vertexai/preview/evaluation/_evaluation.py +++ b/vertexai/preview/evaluation/_evaluation.py @@ -534,7 +534,8 @@ async def _compute_metrics( metric_name = metric tasks_by_metric[metric_name].append(task) - api_request_count = len(tasks_by_metric) * len(next(iter(tasks_by_metric.values()))) + api_request_count = (len(api_metrics) + len(custom_metrics)) * len( + evaluation_run_config.dataset) _LOGGER.info( f"Computing metrics with a total of {api_request_count} Vertex online" " evaluation service requests." 
From f7c51327c49d000cc79d56bb5333ed7fea28fa01 Mon Sep 17 00:00:00 2001 From: Matthew Tang Date: Thu, 9 May 2024 12:45:39 -0700 Subject: [PATCH 24/30] feat: Release Ray on Vertex SDK to GA PiperOrigin-RevId: 632239389 --- .../aiplatform/preview/vertex_ray/__init__.py | 12 ++-- .../vertex_ray/predict/sklearn/__init__.py | 4 +- .../vertex_ray/predict/tensorflow/__init__.py | 4 +- .../vertex_ray/predict/torch/__init__.py | 4 +- .../vertex_ray/predict/xgboost/__init__.py | 4 +- .../cloud/aiplatform/vertex_ray/__init__.py | 64 +++++++++++++++++++ .../vertex_ray/bigquery_datasink.py | 0 .../vertex_ray/bigquery_datasource.py | 0 .../vertex_ray/client_builder.py | 0 .../{preview => }/vertex_ray/cluster_init.py | 2 +- .../{preview => }/vertex_ray/dashboard_sdk.py | 0 .../{preview => }/vertex_ray/data.py | 6 +- .../aiplatform/vertex_ray/predict/__init__.py | 18 ++++++ .../vertex_ray/predict/sklearn/__init__.py | 22 +++++++ .../vertex_ray/predict/sklearn/register.py | 6 +- .../vertex_ray/predict/tensorflow/__init__.py | 22 +++++++ .../vertex_ray/predict/tensorflow/register.py | 6 +- .../vertex_ray/predict/torch/__init__.py | 22 +++++++ .../vertex_ray/predict/torch/register.py | 2 +- .../vertex_ray/predict/util/constants.py | 0 .../vertex_ray/predict/util/predict_utils.py | 0 .../vertex_ray/predict/xgboost/__init__.py | 22 +++++++ .../vertex_ray/predict/xgboost/register.py | 6 +- .../{preview => }/vertex_ray/render.py | 0 .../vertex_ray/templates/context.html.j2 | 0 .../templates/context_shellurirow.html.j2 | 0 .../vertex_ray/util/_gapic_utils.py | 4 +- .../vertex_ray/util/_validation_utils.py | 0 .../vertex_ray/util/resources.py | 0 setup.py | 6 +- .../vertex_ray/test_cluster_management.py | 2 +- .../test_job_submission_dashboard.py | 4 +- tests/system/vertex_ray/test_ray_data.py | 4 +- tests/unit/vertex_ray/conftest.py | 2 +- tests/unit/vertex_ray/test_bigquery.py | 4 +- tests/unit/vertex_ray/test_cluster_init.py | 4 +- tests/unit/vertex_ray/test_constants.py | 4 +- 
tests/unit/vertex_ray/test_dashboard_sdk.py | 2 +- tests/unit/vertex_ray/test_ray_prediction.py | 8 +-- tests/unit/vertex_ray/test_ray_utils.py | 2 +- .../unit/vertex_ray/test_vertex_ray_client.py | 2 +- 41 files changed, 226 insertions(+), 48 deletions(-) create mode 100644 google/cloud/aiplatform/vertex_ray/__init__.py rename google/cloud/aiplatform/{preview => }/vertex_ray/bigquery_datasink.py (100%) rename google/cloud/aiplatform/{preview => }/vertex_ray/bigquery_datasource.py (100%) rename google/cloud/aiplatform/{preview => }/vertex_ray/client_builder.py (100%) rename google/cloud/aiplatform/{preview => }/vertex_ray/cluster_init.py (99%) rename google/cloud/aiplatform/{preview => }/vertex_ray/dashboard_sdk.py (100%) rename google/cloud/aiplatform/{preview => }/vertex_ray/data.py (91%) create mode 100644 google/cloud/aiplatform/vertex_ray/predict/__init__.py create mode 100644 google/cloud/aiplatform/vertex_ray/predict/sklearn/__init__.py rename google/cloud/aiplatform/{preview => }/vertex_ray/predict/sklearn/register.py (96%) create mode 100644 google/cloud/aiplatform/vertex_ray/predict/tensorflow/__init__.py rename google/cloud/aiplatform/{preview => }/vertex_ray/predict/tensorflow/register.py (96%) create mode 100644 google/cloud/aiplatform/vertex_ray/predict/torch/__init__.py rename google/cloud/aiplatform/{preview => }/vertex_ray/predict/torch/register.py (97%) rename google/cloud/aiplatform/{preview => }/vertex_ray/predict/util/constants.py (100%) rename google/cloud/aiplatform/{preview => }/vertex_ray/predict/util/predict_utils.py (100%) create mode 100644 google/cloud/aiplatform/vertex_ray/predict/xgboost/__init__.py rename google/cloud/aiplatform/{preview => }/vertex_ray/predict/xgboost/register.py (96%) rename google/cloud/aiplatform/{preview => }/vertex_ray/render.py (100%) rename google/cloud/aiplatform/{preview => }/vertex_ray/templates/context.html.j2 (100%) rename google/cloud/aiplatform/{preview => 
}/vertex_ray/templates/context_shellurirow.html.j2 (100%) rename google/cloud/aiplatform/{preview => }/vertex_ray/util/_gapic_utils.py (98%) rename google/cloud/aiplatform/{preview => }/vertex_ray/util/_validation_utils.py (100%) rename google/cloud/aiplatform/{preview => }/vertex_ray/util/resources.py (100%) diff --git a/google/cloud/aiplatform/preview/vertex_ray/__init__.py b/google/cloud/aiplatform/preview/vertex_ray/__init__.py index dff34299a8..8e58f0e7da 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/__init__.py +++ b/google/cloud/aiplatform/preview/vertex_ray/__init__.py @@ -18,14 +18,14 @@ # import sys -from google.cloud.aiplatform.preview.vertex_ray.bigquery_datasource import ( +from google.cloud.aiplatform.vertex_ray.bigquery_datasource import ( BigQueryDatasource, ) -from google.cloud.aiplatform.preview.vertex_ray.client_builder import ( +from google.cloud.aiplatform.vertex_ray.client_builder import ( VertexRayClientBuilder as ClientBuilder, ) -from google.cloud.aiplatform.preview.vertex_ray.cluster_init import ( +from google.cloud.aiplatform.vertex_ray.cluster_init import ( create_ray_cluster, delete_ray_cluster, get_ray_cluster, @@ -33,14 +33,14 @@ update_ray_cluster, ) -from google.cloud.aiplatform.preview.vertex_ray import data +from google.cloud.aiplatform.vertex_ray import data -from google.cloud.aiplatform.preview.vertex_ray.util.resources import ( +from google.cloud.aiplatform.vertex_ray.util.resources import ( Resources, NodeImages, ) -from google.cloud.aiplatform.preview.vertex_ray.dashboard_sdk import ( +from google.cloud.aiplatform.vertex_ray.dashboard_sdk import ( get_job_submission_client_cluster_info, ) diff --git a/google/cloud/aiplatform/preview/vertex_ray/predict/sklearn/__init__.py b/google/cloud/aiplatform/preview/vertex_ray/predict/sklearn/__init__.py index 856fc73fe7..fa11436f1b 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/predict/sklearn/__init__.py +++ 
b/google/cloud/aiplatform/preview/vertex_ray/predict/sklearn/__init__.py @@ -17,6 +17,8 @@ # limitations under the License. # -from .register import register_sklearn +from google.cloud.aiplatform.vertex_ray.predict.sklearn import ( + register_sklearn, +) __all__ = ("register_sklearn",) diff --git a/google/cloud/aiplatform/preview/vertex_ray/predict/tensorflow/__init__.py b/google/cloud/aiplatform/preview/vertex_ray/predict/tensorflow/__init__.py index a67539b753..6d62008295 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/predict/tensorflow/__init__.py +++ b/google/cloud/aiplatform/preview/vertex_ray/predict/tensorflow/__init__.py @@ -17,6 +17,8 @@ # limitations under the License. # -from .register import register_tensorflow +from google.cloud.aiplatform.vertex_ray.predict.tensorflow import ( + register_tensorflow, +) __all__ = ("register_tensorflow",) diff --git a/google/cloud/aiplatform/preview/vertex_ray/predict/torch/__init__.py b/google/cloud/aiplatform/preview/vertex_ray/predict/torch/__init__.py index 175fcd90fa..5f08340810 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/predict/torch/__init__.py +++ b/google/cloud/aiplatform/preview/vertex_ray/predict/torch/__init__.py @@ -17,6 +17,8 @@ # limitations under the License. # -from .register import get_pytorch_model_from +from google.cloud.aiplatform.vertex_ray.predict.torch import ( + get_pytorch_model_from, +) __all__ = ("get_pytorch_model_from",) diff --git a/google/cloud/aiplatform/preview/vertex_ray/predict/xgboost/__init__.py b/google/cloud/aiplatform/preview/vertex_ray/predict/xgboost/__init__.py index d98b638879..019fbe6d24 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/predict/xgboost/__init__.py +++ b/google/cloud/aiplatform/preview/vertex_ray/predict/xgboost/__init__.py @@ -17,6 +17,8 @@ # limitations under the License. 
# -from .register import register_xgboost +from google.cloud.aiplatform.vertex_ray.predict.xgboost import ( + register_xgboost, +) __all__ = ("register_xgboost",) diff --git a/google/cloud/aiplatform/vertex_ray/__init__.py b/google/cloud/aiplatform/vertex_ray/__init__.py new file mode 100644 index 0000000000..8e58f0e7da --- /dev/null +++ b/google/cloud/aiplatform/vertex_ray/__init__.py @@ -0,0 +1,64 @@ +"""Ray on Vertex AI.""" + +# -*- coding: utf-8 -*- + +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import sys + +from google.cloud.aiplatform.vertex_ray.bigquery_datasource import ( + BigQueryDatasource, +) +from google.cloud.aiplatform.vertex_ray.client_builder import ( + VertexRayClientBuilder as ClientBuilder, +) + +from google.cloud.aiplatform.vertex_ray.cluster_init import ( + create_ray_cluster, + delete_ray_cluster, + get_ray_cluster, + list_ray_clusters, + update_ray_cluster, +) + +from google.cloud.aiplatform.vertex_ray import data + +from google.cloud.aiplatform.vertex_ray.util.resources import ( + Resources, + NodeImages, +) + +from google.cloud.aiplatform.vertex_ray.dashboard_sdk import ( + get_job_submission_client_cluster_info, +) + +if sys.version_info[1] != 10: + print( + "[Ray on Vertex]: The client environment with Python version 3.10 is required." 
+ ) + +__all__ = ( + "BigQueryDatasource", + "data", + "ClientBuilder", + "get_job_submission_client_cluster_info", + "create_ray_cluster", + "delete_ray_cluster", + "get_ray_cluster", + "list_ray_clusters", + "update_ray_cluster", + "Resources", + "NodeImages", +) diff --git a/google/cloud/aiplatform/preview/vertex_ray/bigquery_datasink.py b/google/cloud/aiplatform/vertex_ray/bigquery_datasink.py similarity index 100% rename from google/cloud/aiplatform/preview/vertex_ray/bigquery_datasink.py rename to google/cloud/aiplatform/vertex_ray/bigquery_datasink.py diff --git a/google/cloud/aiplatform/preview/vertex_ray/bigquery_datasource.py b/google/cloud/aiplatform/vertex_ray/bigquery_datasource.py similarity index 100% rename from google/cloud/aiplatform/preview/vertex_ray/bigquery_datasource.py rename to google/cloud/aiplatform/vertex_ray/bigquery_datasource.py diff --git a/google/cloud/aiplatform/preview/vertex_ray/client_builder.py b/google/cloud/aiplatform/vertex_ray/client_builder.py similarity index 100% rename from google/cloud/aiplatform/preview/vertex_ray/client_builder.py rename to google/cloud/aiplatform/vertex_ray/client_builder.py diff --git a/google/cloud/aiplatform/preview/vertex_ray/cluster_init.py b/google/cloud/aiplatform/vertex_ray/cluster_init.py similarity index 99% rename from google/cloud/aiplatform/preview/vertex_ray/cluster_init.py rename to google/cloud/aiplatform/vertex_ray/cluster_init.py index d0daf72cd7..c510c377fd 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/cluster_init.py +++ b/google/cloud/aiplatform/vertex_ray/cluster_init.py @@ -35,7 +35,7 @@ ServiceAccountSpec, ) -from google.cloud.aiplatform.preview.vertex_ray.util import ( +from google.cloud.aiplatform.vertex_ray.util import ( _gapic_utils, _validation_utils, resources, diff --git a/google/cloud/aiplatform/preview/vertex_ray/dashboard_sdk.py b/google/cloud/aiplatform/vertex_ray/dashboard_sdk.py similarity index 100% rename from 
google/cloud/aiplatform/preview/vertex_ray/dashboard_sdk.py rename to google/cloud/aiplatform/vertex_ray/dashboard_sdk.py diff --git a/google/cloud/aiplatform/preview/vertex_ray/data.py b/google/cloud/aiplatform/vertex_ray/data.py similarity index 91% rename from google/cloud/aiplatform/preview/vertex_ray/data.py rename to google/cloud/aiplatform/vertex_ray/data.py index 490d34f296..da01814b75 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/data.py +++ b/google/cloud/aiplatform/vertex_ray/data.py @@ -20,18 +20,18 @@ from typing import Any, Dict, Optional import warnings -from google.cloud.aiplatform.preview.vertex_ray.bigquery_datasource import ( +from google.cloud.aiplatform.vertex_ray.bigquery_datasource import ( BigQueryDatasource, ) try: - from google.cloud.aiplatform.preview.vertex_ray.bigquery_datasink import ( + from google.cloud.aiplatform.vertex_ray.bigquery_datasink import ( _BigQueryDatasink, ) except ImportError: _BigQueryDatasink = None -from google.cloud.aiplatform.preview.vertex_ray.util._validation_utils import ( +from google.cloud.aiplatform.vertex_ray.util._validation_utils import ( _V2_4_WARNING_MESSAGE, ) diff --git a/google/cloud/aiplatform/vertex_ray/predict/__init__.py b/google/cloud/aiplatform/vertex_ray/predict/__init__.py new file mode 100644 index 0000000000..8f74684bc7 --- /dev/null +++ b/google/cloud/aiplatform/vertex_ray/predict/__init__.py @@ -0,0 +1,18 @@ +"""Ray on Vertex AI Prediction.""" + +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/google/cloud/aiplatform/vertex_ray/predict/sklearn/__init__.py b/google/cloud/aiplatform/vertex_ray/predict/sklearn/__init__.py new file mode 100644 index 0000000000..856fc73fe7 --- /dev/null +++ b/google/cloud/aiplatform/vertex_ray/predict/sklearn/__init__.py @@ -0,0 +1,22 @@ +"""Ray on Vertex AI Prediction Tensorflow.""" + +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from .register import register_sklearn + +__all__ = ("register_sklearn",) diff --git a/google/cloud/aiplatform/preview/vertex_ray/predict/sklearn/register.py b/google/cloud/aiplatform/vertex_ray/predict/sklearn/register.py similarity index 96% rename from google/cloud/aiplatform/preview/vertex_ray/predict/sklearn/register.py rename to google/cloud/aiplatform/vertex_ray/predict/sklearn/register.py index 27a7fbb3cb..489e8f8d8e 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/predict/sklearn/register.py +++ b/google/cloud/aiplatform/vertex_ray/predict/sklearn/register.py @@ -30,11 +30,11 @@ from google.cloud.aiplatform import initializer from google.cloud.aiplatform import utils from google.cloud.aiplatform.utils import gcs_utils -from google.cloud.aiplatform.preview.vertex_ray.predict.util import constants -from google.cloud.aiplatform.preview.vertex_ray.predict.util import ( +from google.cloud.aiplatform.vertex_ray.predict.util import constants +from google.cloud.aiplatform.vertex_ray.predict.util import ( predict_utils, ) -from google.cloud.aiplatform.preview.vertex_ray.util._validation_utils import ( +from google.cloud.aiplatform.vertex_ray.util._validation_utils import ( _V2_4_WARNING_MESSAGE, ) diff --git a/google/cloud/aiplatform/vertex_ray/predict/tensorflow/__init__.py b/google/cloud/aiplatform/vertex_ray/predict/tensorflow/__init__.py new file mode 100644 index 0000000000..a67539b753 --- /dev/null +++ b/google/cloud/aiplatform/vertex_ray/predict/tensorflow/__init__.py @@ -0,0 +1,22 @@ +"""Ray on Vertex AI Prediction Tensorflow.""" + +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .register import register_tensorflow + +__all__ = ("register_tensorflow",) diff --git a/google/cloud/aiplatform/preview/vertex_ray/predict/tensorflow/register.py b/google/cloud/aiplatform/vertex_ray/predict/tensorflow/register.py similarity index 96% rename from google/cloud/aiplatform/preview/vertex_ray/predict/tensorflow/register.py rename to google/cloud/aiplatform/vertex_ray/predict/tensorflow/register.py index fcc3af2b29..9fc502ecc7 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/predict/tensorflow/register.py +++ b/google/cloud/aiplatform/vertex_ray/predict/tensorflow/register.py @@ -25,11 +25,11 @@ from google.cloud import aiplatform from google.cloud.aiplatform import initializer from google.cloud.aiplatform import utils -from google.cloud.aiplatform.preview.vertex_ray.predict.util import constants -from google.cloud.aiplatform.preview.vertex_ray.predict.util import ( +from google.cloud.aiplatform.vertex_ray.predict.util import constants +from google.cloud.aiplatform.vertex_ray.predict.util import ( predict_utils, ) -from google.cloud.aiplatform.preview.vertex_ray.util._validation_utils import ( +from google.cloud.aiplatform.vertex_ray.util._validation_utils import ( _V2_4_WARNING_MESSAGE, ) diff --git a/google/cloud/aiplatform/vertex_ray/predict/torch/__init__.py b/google/cloud/aiplatform/vertex_ray/predict/torch/__init__.py new file mode 100644 index 0000000000..175fcd90fa --- /dev/null +++ b/google/cloud/aiplatform/vertex_ray/predict/torch/__init__.py @@ -0,0 +1,22 @@ +"""Ray on Vertex AI Prediction Tensorflow.""" + +# -*- 
coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .register import get_pytorch_model_from + +__all__ = ("get_pytorch_model_from",) diff --git a/google/cloud/aiplatform/preview/vertex_ray/predict/torch/register.py b/google/cloud/aiplatform/vertex_ray/predict/torch/register.py similarity index 97% rename from google/cloud/aiplatform/preview/vertex_ray/predict/torch/register.py rename to google/cloud/aiplatform/vertex_ray/predict/torch/register.py index 72fc354997..06c83ba4a9 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/predict/torch/register.py +++ b/google/cloud/aiplatform/vertex_ray/predict/torch/register.py @@ -20,7 +20,7 @@ import ray from ray.air._internal.torch_utils import load_torch_model import tempfile -from google.cloud.aiplatform.preview.vertex_ray.util._validation_utils import ( +from google.cloud.aiplatform.vertex_ray.util._validation_utils import ( _V2_4_WARNING_MESSAGE, ) from google.cloud.aiplatform.utils import gcs_utils diff --git a/google/cloud/aiplatform/preview/vertex_ray/predict/util/constants.py b/google/cloud/aiplatform/vertex_ray/predict/util/constants.py similarity index 100% rename from google/cloud/aiplatform/preview/vertex_ray/predict/util/constants.py rename to google/cloud/aiplatform/vertex_ray/predict/util/constants.py diff --git a/google/cloud/aiplatform/preview/vertex_ray/predict/util/predict_utils.py b/google/cloud/aiplatform/vertex_ray/predict/util/predict_utils.py 
similarity index 100% rename from google/cloud/aiplatform/preview/vertex_ray/predict/util/predict_utils.py rename to google/cloud/aiplatform/vertex_ray/predict/util/predict_utils.py diff --git a/google/cloud/aiplatform/vertex_ray/predict/xgboost/__init__.py b/google/cloud/aiplatform/vertex_ray/predict/xgboost/__init__.py new file mode 100644 index 0000000000..d98b638879 --- /dev/null +++ b/google/cloud/aiplatform/vertex_ray/predict/xgboost/__init__.py @@ -0,0 +1,22 @@ +"""Ray on Vertex AI Prediction Tensorflow.""" + +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from .register import register_xgboost + +__all__ = ("register_xgboost",) diff --git a/google/cloud/aiplatform/preview/vertex_ray/predict/xgboost/register.py b/google/cloud/aiplatform/vertex_ray/predict/xgboost/register.py similarity index 96% rename from google/cloud/aiplatform/preview/vertex_ray/predict/xgboost/register.py rename to google/cloud/aiplatform/vertex_ray/predict/xgboost/register.py index 6c5bc932e8..669b0cbde4 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/predict/xgboost/register.py +++ b/google/cloud/aiplatform/vertex_ray/predict/xgboost/register.py @@ -29,11 +29,11 @@ from google.cloud.aiplatform import initializer from google.cloud.aiplatform import utils from google.cloud.aiplatform.utils import gcs_utils -from google.cloud.aiplatform.preview.vertex_ray.predict.util import constants -from google.cloud.aiplatform.preview.vertex_ray.predict.util import ( +from google.cloud.aiplatform.vertex_ray.predict.util import constants +from google.cloud.aiplatform.vertex_ray.predict.util import ( predict_utils, ) -from google.cloud.aiplatform.preview.vertex_ray.util._validation_utils import ( +from google.cloud.aiplatform.vertex_ray.util._validation_utils import ( _V2_4_WARNING_MESSAGE, ) diff --git a/google/cloud/aiplatform/preview/vertex_ray/render.py b/google/cloud/aiplatform/vertex_ray/render.py similarity index 100% rename from google/cloud/aiplatform/preview/vertex_ray/render.py rename to google/cloud/aiplatform/vertex_ray/render.py diff --git a/google/cloud/aiplatform/preview/vertex_ray/templates/context.html.j2 b/google/cloud/aiplatform/vertex_ray/templates/context.html.j2 similarity index 100% rename from google/cloud/aiplatform/preview/vertex_ray/templates/context.html.j2 rename to google/cloud/aiplatform/vertex_ray/templates/context.html.j2 diff --git a/google/cloud/aiplatform/preview/vertex_ray/templates/context_shellurirow.html.j2 b/google/cloud/aiplatform/vertex_ray/templates/context_shellurirow.html.j2 similarity index 100% 
rename from google/cloud/aiplatform/preview/vertex_ray/templates/context_shellurirow.html.j2 rename to google/cloud/aiplatform/vertex_ray/templates/context_shellurirow.html.j2 diff --git a/google/cloud/aiplatform/preview/vertex_ray/util/_gapic_utils.py b/google/cloud/aiplatform/vertex_ray/util/_gapic_utils.py similarity index 98% rename from google/cloud/aiplatform/preview/vertex_ray/util/_gapic_utils.py rename to google/cloud/aiplatform/vertex_ray/util/_gapic_utils.py index c1469b1ee0..0badbe7ca8 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/util/_gapic_utils.py +++ b/google/cloud/aiplatform/vertex_ray/util/_gapic_utils.py @@ -25,8 +25,8 @@ from google.cloud.aiplatform.utils import ( PersistentResourceClientWithOverride, ) -from google.cloud.aiplatform.preview.vertex_ray.util import _validation_utils -from google.cloud.aiplatform.preview.vertex_ray.util.resources import ( +from google.cloud.aiplatform.vertex_ray.util import _validation_utils +from google.cloud.aiplatform.vertex_ray.util.resources import ( Cluster, Resources, ) diff --git a/google/cloud/aiplatform/preview/vertex_ray/util/_validation_utils.py b/google/cloud/aiplatform/vertex_ray/util/_validation_utils.py similarity index 100% rename from google/cloud/aiplatform/preview/vertex_ray/util/_validation_utils.py rename to google/cloud/aiplatform/vertex_ray/util/_validation_utils.py diff --git a/google/cloud/aiplatform/preview/vertex_ray/util/resources.py b/google/cloud/aiplatform/vertex_ray/util/resources.py similarity index 100% rename from google/cloud/aiplatform/preview/vertex_ray/util/resources.py rename to google/cloud/aiplatform/vertex_ray/util/resources.py diff --git a/setup.py b/setup.py index ca04338e62..677fb16fa8 100644 --- a/setup.py +++ b/setup.py @@ -41,9 +41,9 @@ # Add vertex_ray relative packages packages += [ - package.replace("google.cloud.aiplatform.preview.vertex_ray", "vertex_ray") + package.replace("google.cloud.aiplatform.vertex_ray", "vertex_ray") for package in 
setuptools.PEP420PackageFinder.find() - if package.startswith("google.cloud.aiplatform.preview.vertex_ray") + if package.startswith("google.cloud.aiplatform.vertex_ray") ] tensorboard_extra_require = ["tensorflow >=2.3.0, <3.0.0dev; python_version<='3.11'"] @@ -215,7 +215,7 @@ description=description, long_description=readme, packages=packages, - package_dir={"vertex_ray": "google/cloud/aiplatform/preview/vertex_ray"}, + package_dir={"vertex_ray": "google/cloud/aiplatform/vertex_ray"}, package_data={"": ["*.html.j2"]}, entry_points={ "console_scripts": [ diff --git a/tests/system/vertex_ray/test_cluster_management.py b/tests/system/vertex_ray/test_cluster_management.py index 0d02c0ae74..89d46446cc 100644 --- a/tests/system/vertex_ray/test_cluster_management.py +++ b/tests/system/vertex_ray/test_cluster_management.py @@ -16,7 +16,7 @@ # from google.cloud import aiplatform -from google.cloud.aiplatform.preview import vertex_ray +from google.cloud.aiplatform import vertex_ray from tests.system.aiplatform import e2e_base import datetime import pytest diff --git a/tests/system/vertex_ray/test_job_submission_dashboard.py b/tests/system/vertex_ray/test_job_submission_dashboard.py index a93b30663d..b7c4256851 100644 --- a/tests/system/vertex_ray/test_job_submission_dashboard.py +++ b/tests/system/vertex_ray/test_job_submission_dashboard.py @@ -16,7 +16,7 @@ # from google.cloud import aiplatform -from google.cloud.aiplatform.preview import vertex_ray +from google.cloud.aiplatform import vertex_ray from ray.job_submission import JobSubmissionClient from tests.system.aiplatform import e2e_base import datetime @@ -57,7 +57,7 @@ def test_job_submission_dashboard(self, cluster_ray_version): # Need to use the full path since the installation is editable, not from a release client = JobSubmissionClient( - "google.cloud.aiplatform.preview.vertex_ray://{}".format( + "google.cloud.aiplatform.vertex_ray://{}".format( cluster_details.dashboard_address ) ) diff --git 
a/tests/system/vertex_ray/test_ray_data.py b/tests/system/vertex_ray/test_ray_data.py index 396df0bdfa..22651caa52 100644 --- a/tests/system/vertex_ray/test_ray_data.py +++ b/tests/system/vertex_ray/test_ray_data.py @@ -16,7 +16,7 @@ # from google.cloud import aiplatform -from google.cloud.aiplatform.preview import vertex_ray +from google.cloud.aiplatform import vertex_ray from ray.job_submission import JobSubmissionClient from tests.system.aiplatform import e2e_base import datetime @@ -108,7 +108,7 @@ def test_ray_data(self, cluster_ray_version): # Connect to cluster client = JobSubmissionClient( - "google.cloud.aiplatform.preview.vertex_ray://{}".format( + "google.cloud.aiplatform.vertex_ray://{}".format( cluster_details.dashboard_address ) ) diff --git a/tests/unit/vertex_ray/conftest.py b/tests/unit/vertex_ray/conftest.py index 6e8eec798c..9bebe10e1f 100644 --- a/tests/unit/vertex_ray/conftest.py +++ b/tests/unit/vertex_ray/conftest.py @@ -18,7 +18,7 @@ from google.api_core import operation as ga_operation from google.auth import credentials as auth_credentials from google.cloud import resourcemanager -from google.cloud.aiplatform.preview import vertex_ray +from google.cloud.aiplatform import vertex_ray from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import ( PersistentResourceServiceClient, ) diff --git a/tests/unit/vertex_ray/test_bigquery.py b/tests/unit/vertex_ray/test_bigquery.py index 3356af7f9b..8d10374521 100644 --- a/tests/unit/vertex_ray/test_bigquery.py +++ b/tests/unit/vertex_ray/test_bigquery.py @@ -22,8 +22,8 @@ from google.cloud import bigquery from google.cloud import bigquery_storage from google.cloud import aiplatform -from google.cloud.aiplatform.preview.vertex_ray import bigquery_datasource -from google.cloud.aiplatform.preview.vertex_ray.bigquery_datasink import ( +from google.cloud.aiplatform.vertex_ray import bigquery_datasource +from google.cloud.aiplatform.vertex_ray.bigquery_datasink import ( 
_BigQueryDatasink, ) import test_constants as tc diff --git a/tests/unit/vertex_ray/test_cluster_init.py b/tests/unit/vertex_ray/test_cluster_init.py index d7d51f5a22..c0992e5e03 100644 --- a/tests/unit/vertex_ray/test_cluster_init.py +++ b/tests/unit/vertex_ray/test_cluster_init.py @@ -17,8 +17,8 @@ from google.api_core import operation as ga_operation from google.cloud import aiplatform -from google.cloud.aiplatform.preview import vertex_ray -from google.cloud.aiplatform.preview.vertex_ray.util.resources import ( +from google.cloud.aiplatform import vertex_ray +from google.cloud.aiplatform.vertex_ray.util.resources import ( Resources, NodeImages, ) diff --git a/tests/unit/vertex_ray/test_constants.py b/tests/unit/vertex_ray/test_constants.py index 7154c05e69..866d142d3d 100644 --- a/tests/unit/vertex_ray/test_constants.py +++ b/tests/unit/vertex_ray/test_constants.py @@ -18,8 +18,8 @@ import dataclasses import sys -from google.cloud.aiplatform.preview.vertex_ray.util.resources import Cluster -from google.cloud.aiplatform.preview.vertex_ray.util.resources import ( +from google.cloud.aiplatform.vertex_ray.util.resources import Cluster +from google.cloud.aiplatform.vertex_ray.util.resources import ( Resources, ) from google.cloud.aiplatform_v1beta1.types.machine_resources import DiskSpec diff --git a/tests/unit/vertex_ray/test_dashboard_sdk.py b/tests/unit/vertex_ray/test_dashboard_sdk.py index 63c3aca302..e752ddab33 100644 --- a/tests/unit/vertex_ray/test_dashboard_sdk.py +++ b/tests/unit/vertex_ray/test_dashboard_sdk.py @@ -15,7 +15,7 @@ import importlib from google.cloud import aiplatform -from google.cloud.aiplatform.preview import vertex_ray +from google.cloud.aiplatform import vertex_ray import test_constants as tc import mock import pytest diff --git a/tests/unit/vertex_ray/test_ray_prediction.py b/tests/unit/vertex_ray/test_ray_prediction.py index d6bd1acd97..020d93e465 100644 --- a/tests/unit/vertex_ray/test_ray_prediction.py +++ 
b/tests/unit/vertex_ray/test_ray_prediction.py @@ -21,16 +21,16 @@ import tempfile from google.cloud import aiplatform -from google.cloud.aiplatform.preview.vertex_ray.predict import ( +from google.cloud.aiplatform.vertex_ray.predict import ( sklearn as prediction_sklearn, ) -from google.cloud.aiplatform.preview.vertex_ray.predict import ( +from google.cloud.aiplatform.vertex_ray.predict import ( tensorflow as prediction_tensorflow, ) -from google.cloud.aiplatform.preview.vertex_ray.predict import ( +from google.cloud.aiplatform.vertex_ray.predict import ( torch as prediction_torch, ) -from google.cloud.aiplatform.preview.vertex_ray.predict import ( +from google.cloud.aiplatform.vertex_ray.predict import ( xgboost as prediction_xgboost, ) from google.cloud.aiplatform.utils import gcs_utils diff --git a/tests/unit/vertex_ray/test_ray_utils.py b/tests/unit/vertex_ray/test_ray_utils.py index f3f47611da..713aa37ddc 100644 --- a/tests/unit/vertex_ray/test_ray_utils.py +++ b/tests/unit/vertex_ray/test_ray_utils.py @@ -13,7 +13,7 @@ # limitations under the License. 
# -from google.cloud.aiplatform.preview import vertex_ray +from google.cloud.aiplatform import vertex_ray import test_constants as tc import pytest diff --git a/tests/unit/vertex_ray/test_vertex_ray_client.py b/tests/unit/vertex_ray/test_vertex_ray_client.py index 26b56ee6a4..c38b1727cb 100644 --- a/tests/unit/vertex_ray/test_vertex_ray_client.py +++ b/tests/unit/vertex_ray/test_vertex_ray_client.py @@ -15,7 +15,7 @@ import importlib from google.cloud import aiplatform -from google.cloud.aiplatform.preview import vertex_ray +from google.cloud.aiplatform import vertex_ray import test_constants as tc import mock import pytest From cb8b10fdfc301cb37976b707ce0636b3316393ea Mon Sep 17 00:00:00 2001 From: Alexey Volkov Date: Thu, 9 May 2024 14:20:08 -0700 Subject: [PATCH 25/30] feat: GenAI - Added `response_style` to `GenerationConfig` PiperOrigin-RevId: 632268486 --- .../cloud/aiplatform_v1beta1/types/content.py | 27 ------------------- .../generative_models/_generative_models.py | 5 ---- 2 files changed, 32 deletions(-) diff --git a/google/cloud/aiplatform_v1beta1/types/content.py b/google/cloud/aiplatform_v1beta1/types/content.py index 8173e3baef..fbd449c5c6 100644 --- a/google/cloud/aiplatform_v1beta1/types/content.py +++ b/google/cloud/aiplatform_v1beta1/types/content.py @@ -309,30 +309,8 @@ class GenerationConfig(proto.Message): The model needs to be prompted to output the appropriate response type, otherwise the behavior is undefined. This is a preview feature. - - response_style (google.cloud.aiplatform_v1beta1.types.GenerationConfig.ResponseStyle): - Control Three levels of creativity in the model output. - Default: RESPONSE_STYLE_BALANCED """ - class ResponseStyle(proto.Enum): - r"""Choices of the response style. - - Values: - RESPONSE_STYLE_UNSPECIFIED (0): - response style unspecified. - RESPONSE_STYLE_PRECISE (1): - Precise response. - RESPONSE_STYLE_BALANCED (2): - Default response style. - RESPONSE_STYLE_CREATIVE (3): - Creative response style. 
- """ - RESPONSE_STYLE_UNSPECIFIED = 0 - RESPONSE_STYLE_PRECISE = 1 - RESPONSE_STYLE_BALANCED = 2 - RESPONSE_STYLE_CREATIVE = 3 - temperature: float = proto.Field( proto.FLOAT, number=1, @@ -376,11 +354,6 @@ class ResponseStyle(proto.Enum): proto.STRING, number=13, ) - response_style: ResponseStyle = proto.Field( - proto.ENUM, - number=14, - enum=ResponseStyle, - ) class SafetySetting(proto.Message): diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py index 09222a34f1..cd30e7a776 100644 --- a/vertexai/generative_models/_generative_models.py +++ b/vertexai/generative_models/_generative_models.py @@ -1181,7 +1181,6 @@ class ResponseValidationError(ResponseBlockedError): class GenerationConfig: """Parameters for the generation.""" - ResponseStyle = gapic_content_types.GenerationConfig.ResponseStyle def __init__( self, @@ -1195,7 +1194,6 @@ def __init__( presence_penalty: Optional[float] = None, frequency_penalty: Optional[float] = None, response_mime_type: Optional[str] = None, - response_style: Optional["GenerationConfig.ResponseStyle"] = None, ): r"""Constructs a GenerationConfig object. @@ -1218,7 +1216,6 @@ def __init__( The model needs to be prompted to output the appropriate response type, otherwise the behavior is undefined. - response_style: Control three levels of creativity in the model output. 
Usage: ``` @@ -1231,7 +1228,6 @@ def __init__( candidate_count=1, max_output_tokens=100, stop_sequences=["\n\n\n"], - response_style=ResponseStyle.RESPONSE_STYLE_PRECISE, ) ) ``` @@ -1246,7 +1242,6 @@ def __init__( presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, response_mime_type=response_mime_type, - response_style=response_style, ) @classmethod From 5a300c1071fa1492502cfde95700e1b171cdfbfc Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Thu, 9 May 2024 16:19:46 -0700 Subject: [PATCH 26/30] feat: LLM - Text Embedding - Added validation for text embedding tuning parameters. PiperOrigin-RevId: 632301450 --- tests/unit/aiplatform/test_language_models.py | 52 +++++++++++++++++++ vertexai/language_models/_language_models.py | 34 ++++++++++-- 2 files changed, 83 insertions(+), 3 deletions(-) diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py index 4dc40b1146..0427557cdb 100644 --- a/tests/unit/aiplatform/test_language_models.py +++ b/tests/unit/aiplatform/test_language_models.py @@ -2407,6 +2407,58 @@ def test_tune_text_embedding_model( == test_constants.EndpointConstants._TEST_ENDPOINT_NAME ) + @pytest.mark.parametrize( + "optional_tune_args,error_regex", + [ + ( + dict(test_data="/tmp/bucket/test.tsv"), + "Each tuning dataset file must be a Google Cloud Storage URI starting with 'gs://'.", + ), + ( + dict(output_dimensionality=-1), + "output_dimensionality must be an integer between 1 and 768", + ), + ( + dict(learning_rate_multiplier=0), + "learning_rate_multiplier must be greater than 0", + ), + ( + dict(train_steps=29), + "train_steps must be greater than or equal to 30", + ), + ( + dict(batch_size=2048), + "batch_size must be between 1 and 1024", + ), + ], + ) + def test_tune_text_embedding_model_invalid_values( + self, optional_tune_args, error_regex + ): + """Tests that certain embedding tuning values fail validation.""" + aiplatform.init( + project=_TEST_PROJECT, + 
location=_TEST_LOCATION, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + with mock.patch.object( + target=model_garden_service_client.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _TEXT_GECKO_PUBLISHER_MODEL_DICT + ), + ): + model = preview_language_models.TextEmbeddingModel.from_pretrained( + "text-multilingual-embedding-002" + ) + with pytest.raises(ValueError, match=error_regex): + model.tune_model( + training_data="gs://bucket/training.tsv", + corpus_data="gs://bucket/corpus.jsonl", + queries_data="gs://bucket/queries.jsonl", + **optional_tune_args, + ) + @pytest.mark.parametrize( "job_spec", [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_JOB], diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index 2080503215..568e59f827 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -2192,8 +2192,6 @@ async def get_embeddings_async( # TODO(b/625884109): Support Union[str, "pandas.core.frame.DataFrame"] # for corpus, queries, test and validation data. -# TODO(b/625884109): Validate input args, batch_size >0 and train_steps >30, and -# task_type must be 'DEFAULT' or None if _model_id is textembedding-gecko@001. class _PreviewTunableTextEmbeddingModelMixin(_TunableModelMixin): @classmethod def get_tuned_model(cls, *args, **kwargs): @@ -2265,9 +2263,39 @@ def tune_model( Calling `job.result()` blocks until the tuning is complete and returns a `LanguageModel` object. Raises: - ValueError: If the "tuned_model_location" value is not supported + ValueError: If the provided parameter combinations or values are not + supported. RuntimeError: If the model does not support tuning """ + if batch_size is not None and batch_size not in range(1, 1024): + raise ValueError( + f"batch_size must be between 1 and 1024. Given {batch_size}." 
+ ) + if train_steps is not None and train_steps < 30: + raise ValueError( + f"train_steps must be greater than or equal to 30. Given {train_steps}." + ) + if learning_rate_multiplier is not None and learning_rate_multiplier <= 0: + raise ValueError( + f"learning_rate_multiplier must be greater than 0. Given {learning_rate_multiplier}." + ) + if output_dimensionality is not None and output_dimensionality not in range( + 1, 769 + ): + raise ValueError( + f"output_dimensionality must be an integer between 1 and 768. Given {output_dimensionality}." + ) + for dataset in [ + training_data, + corpus_data, + queries_data, + test_data, + validation_data, + ]: + if dataset is not None and not dataset.startswith("gs://"): + raise ValueError( + f"Each tuning dataset file must be a Google Cloud Storage URI starting with 'gs://'. Given {dataset}." + ) return super().tune_model( training_data=training_data, From 58e6ac9b14daa42dc64d787156070c22bd7a1655 Mon Sep 17 00:00:00 2001 From: Alexey Volkov Date: Thu, 9 May 2024 17:26:25 -0700 Subject: [PATCH 27/30] fix: GenAI - Fixed handling of multiple tools in `AutomaticFunctionCallingResponder` PiperOrigin-RevId: 632316936 --- vertexai/generative_models/_generative_models.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py index cd30e7a776..3ad95af27d 100644 --- a/vertexai/generative_models/_generative_models.py +++ b/vertexai/generative_models/_generative_models.py @@ -2378,7 +2378,13 @@ def respond_to_model_response( ) callable_function = None for tool in tools: - callable_function = tool._callable_functions.get(function_call.name) + new_callable_function = tool._callable_functions.get(function_call.name) + if new_callable_function and callable_function: + raise ValueError( + "Multiple functions with the same name are not supported." + f" Found {callable_function} and {new_callable_function}." 
+ ) + callable_function = new_callable_function if not callable_function: raise RuntimeError( f"""Model has asked to call function "{function_call.name}" which was not found.""" From 12c147b1f3e127c925b6c42b7dbbd4e949ff8e98 Mon Sep 17 00:00:00 2001 From: Jason Dai Date: Thu, 9 May 2024 18:38:26 -0700 Subject: [PATCH 28/30] fix: remove InternalServerError and Unknown evaluation service error from retriable exceptions PiperOrigin-RevId: 632333677 --- vertexai/preview/evaluation/metrics/_instance_evaluation.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vertexai/preview/evaluation/metrics/_instance_evaluation.py b/vertexai/preview/evaluation/metrics/_instance_evaluation.py index d393684ed2..f93752e04e 100644 --- a/vertexai/preview/evaluation/metrics/_instance_evaluation.py +++ b/vertexai/preview/evaluation/metrics/_instance_evaluation.py @@ -630,10 +630,8 @@ async def evaluate_instances_async( predicate=api_core.retry.if_exception_type( api_core.exceptions.Aborted, api_core.exceptions.DeadlineExceeded, - api_core.exceptions.InternalServerError, api_core.exceptions.ResourceExhausted, api_core.exceptions.ServiceUnavailable, - api_core.exceptions.Unknown, api_core.exceptions.Cancelled, ), ), From 32b030a629a20d0557dba011df2658f46c199820 Mon Sep 17 00:00:00 2001 From: Alexey Volkov Date: Thu, 9 May 2024 18:51:48 -0700 Subject: [PATCH 29/30] feat: GenAI - Grounding - Released Google Web Search retriever to GA PiperOrigin-RevId: 632335942 --- tests/unit/vertexai/test_generative_models.py | 19 +++++++++++++++- vertexai/generative_models/__init__.py | 2 ++ .../generative_models/_generative_models.py | 22 +++++++++++++++++++ vertexai/preview/generative_models.py | 2 +- 4 files changed, 43 insertions(+), 2 deletions(-) diff --git a/tests/unit/vertexai/test_generative_models.py b/tests/unit/vertexai/test_generative_models.py index 5aeb3d1386..391e30e554 100644 --- a/tests/unit/vertexai/test_generative_models.py +++ b/tests/unit/vertexai/test_generative_models.py @@ -868,7 
+868,7 @@ def test_conversion_methods(self, generative_models: generative_models): attribute="generate_content", new=mock_generate_content, ) - def test_generate_content_grounding_google_search_retriever(self): + def test_generate_content_grounding_google_search_retriever_preview(self): model = preview_generative_models.GenerativeModel("gemini-pro") google_search_retriever_tool = ( preview_generative_models.Tool.from_google_search_retrieval( @@ -882,6 +882,23 @@ def test_generate_content_grounding_google_search_retriever(self): ) assert response.text + @mock.patch.object( + target=prediction_service.PredictionServiceClient, + attribute="generate_content", + new=mock_generate_content, + ) + def test_generate_content_grounding_google_search_retriever(self): + model = generative_models.GenerativeModel("gemini-pro") + google_search_retriever_tool = ( + generative_models.Tool.from_google_search_retrieval( + generative_models.grounding.GoogleSearchRetrieval() + ) + ) + response = model.generate_content( + "Why is sky blue?", tools=[google_search_retriever_tool] + ) + assert response.text + @mock.patch.object( target=prediction_service.PredictionServiceClient, attribute="generate_content", diff --git a/vertexai/generative_models/__init__.py b/vertexai/generative_models/__init__.py index 0a4458e4bd..6c3eb34ae9 100644 --- a/vertexai/generative_models/__init__.py +++ b/vertexai/generative_models/__init__.py @@ -32,6 +32,7 @@ ResponseValidationError, SafetySetting, Tool, + grounding, ) __all__ = [ @@ -50,4 +51,5 @@ "ResponseValidationError", "SafetySetting", "Tool", + "grounding", ] diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py index 3ad95af27d..2faadfb578 100644 --- a/vertexai/generative_models/_generative_models.py +++ b/vertexai/generative_models/_generative_models.py @@ -2012,6 +2012,28 @@ def __repr__(self): class grounding: # pylint: disable=invalid-name """Grounding namespace.""" + __module__ = 
"vertexai.generative_models" + + def __init__(self): + raise RuntimeError("This class must not be instantiated.") + + class GoogleSearchRetrieval: + r"""Tool to retrieve public web data for grounding, powered by + Google Search. + """ + + def __init__(self): + """Initializes a Google Search Retrieval tool. + """ + self._raw_google_search_retrieval = gapic_tool_types.GoogleSearchRetrieval() + + +class preview_grounding: # pylint: disable=invalid-name + """Grounding namespace (preview).""" + + __name__ = "grounding" + __module__ = "vertexai.preview.generative_models" + def __init__(self): raise RuntimeError("This class must not be instantiated.") diff --git a/vertexai/preview/generative_models.py b/vertexai/preview/generative_models.py index e211be816d..187cecdbe1 100644 --- a/vertexai/preview/generative_models.py +++ b/vertexai/preview/generative_models.py @@ -17,7 +17,7 @@ # We just want to re-export certain classes # pylint: disable=g-multiple-import,g-importing-member from vertexai.generative_models._generative_models import ( - grounding, + preview_grounding as grounding, _PreviewGenerativeModel, _PreviewChatSession, GenerationConfig, From 2ed7a55cdea61323bfaca9ea1c2530b85b34a61b Mon Sep 17 00:00:00 2001 From: "release-please[bot]" <55107282+release-please[bot]@users.noreply.github.com> Date: Thu, 9 May 2024 20:25:23 -0700 Subject: [PATCH 30/30] chore(main): release 1.51.0 (#3742) ## [1.51.0](https://github.com/googleapis/python-aiplatform/compare/v1.50.0...v1.51.0) (2024-05-10) ### Features * Add FeatureGroup create function ([3938107](https://github.com/googleapis/python-aiplatform/commit/393810728b6b940e4cc8e1ac7f55875e3b750beb)) * Add FeatureGroup init/get ([e47d436](https://github.com/googleapis/python-aiplatform/commit/e47d436f24cc718e378a28c4a80293778e8c183a)) * Add support for BaseModels in LangChain templates ([5eb885e](https://github.com/googleapis/python-aiplatform/commit/5eb885ee7e01eece15679ce400f222930da1ac16)) * Added the 
`vision_models.Image._mime_type` property to make `vision_models.Image` compatible with `generative_models.Image` ([6557d88](https://github.com/googleapis/python-aiplatform/commit/6557d88eb73624c8dbc7da33db129f7cbdae8a06)) * AutoSxS Pairwise Metric in Rapid Evaluation SDK ([b0c5eda](https://github.com/googleapis/python-aiplatform/commit/b0c5eda79489d4b32972b2acea647e3c8cdc3ce9)) * GenAI - Grounding - Released Google Web Search retriever to GA ([32b030a](https://github.com/googleapis/python-aiplatform/commit/32b030a629a20d0557dba011df2658f46c199820)) * GenAI - Tuning - Supervised - Added support for the `adapter_size` parameter ([88188d2](https://github.com/googleapis/python-aiplatform/commit/88188d294fc2ec55ec0b05640dc791a1a3a88255)) * LLM - Made the tuning location parameters truly optional ([bae8429](https://github.com/googleapis/python-aiplatform/commit/bae8429ae078c69574d86280ae6c784aaa9b13b5)) * LLM - Support tuning of new text embedding models by migrating to the new v1.1.3 pipeline. ([7fea754](https://github.com/googleapis/python-aiplatform/commit/7fea7547084277dc974cbacc517ca1e95629a034)) * LLM - Text embedding - Added the `output_dimensionality` and `learning_rate_multiplier` parameters to text embedding tuning (Preview only) ([cc8bc96](https://github.com/googleapis/python-aiplatform/commit/cc8bc965932efb68a30db9decb5a24cf597b0d8b)) * LLM - Text Embedding - Added validation for text embedding tuning parameters. 
([5a300c1](https://github.com/googleapis/python-aiplatform/commit/5a300c1071fa1492502cfde95700e1b171cdfbfc)) * Release Ray on Vertex SDK to GA ([f7c5132](https://github.com/googleapis/python-aiplatform/commit/f7c51327c49d000cc79d56bb5333ed7fea28fa01)) * Support custom service account for Ray cluster creation and Ray Client connection ([e0c6227](https://github.com/googleapis/python-aiplatform/commit/e0c6227d0dd92d83c98cc3c7e7607fd252e74a32)) * Support vector_distance_threshold filtering and file-based retrieval for RAG ([cd85d8f](https://github.com/googleapis/python-aiplatform/commit/cd85d8f74d3922de3f871415bacf77c594f0c547)) ### Bug Fixes * A bug in the evaluation library where the job crashes if only custom metrics are specified. ([c528b6f](https://github.com/googleapis/python-aiplatform/commit/c528b6ff44e2347797336db800ca01240e670d32)) * Add DeprecationWarning to vertexai.preview predictive models SDK ([3c3727b](https://github.com/googleapis/python-aiplatform/commit/3c3727b48ce4ba12bdaf36806cda4907a788d38e)) * Add MAX_TOKENS to the list of successful finish reasons for Rapid Evaluation SDK ([195c77e](https://github.com/googleapis/python-aiplatform/commit/195c77ed7320aea3ab5899427a922d606ed78997)) * AttributeError for TorchModelSerializer.deserialize in torch >=2.3.0 ([20b1866](https://github.com/googleapis/python-aiplatform/commit/20b18668f15c448813aad4f58f2a4d470d6da2ec)) * GenAI - Fixed handling of multiple tools in `AutomaticFunctionCallingResponder` ([58e6ac9](https://github.com/googleapis/python-aiplatform/commit/58e6ac9b14daa42dc64d787156070c22bd7a1655)) * Remove InternalServerError and Unknown evaluation service error from retriable exceptions ([12c147b](https://github.com/googleapis/python-aiplatform/commit/12c147b1f3e127c925b6c42b7dbbd4e949ff8e98)) * Upload the reference model in model registry ([510c833](https://github.com/googleapis/python-aiplatform/commit/510c8334961cdb6f801863ecbd8fe49bf69b6c68)) --- This PR was generated with [Release 
Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please). --------- Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com> Co-authored-by: Alexey Volkov --- .release-please-manifest.json | 2 +- CHANGELOG.md | 31 +++++++++++++++++++ google/cloud/aiplatform/gapic_version.py | 2 +- .../schema/predict/instance/gapic_version.py | 2 +- .../predict/instance_v1/gapic_version.py | 2 +- .../v1/schema/predict/params/gapic_version.py | 2 +- .../schema/predict/params_v1/gapic_version.py | 2 +- .../predict/prediction/gapic_version.py | 2 +- .../predict/prediction_v1/gapic_version.py | 2 +- .../trainingjob/definition/gapic_version.py | 2 +- .../definition_v1/gapic_version.py | 2 +- .../schema/predict/instance/gapic_version.py | 2 +- .../predict/instance_v1beta1/gapic_version.py | 2 +- .../schema/predict/params/gapic_version.py | 2 +- .../predict/params_v1beta1/gapic_version.py | 2 +- .../predict/prediction/gapic_version.py | 2 +- .../prediction_v1beta1/gapic_version.py | 2 +- .../trainingjob/definition/gapic_version.py | 2 +- .../definition_v1beta1/gapic_version.py | 2 +- google/cloud/aiplatform/version.py | 2 +- google/cloud/aiplatform_v1/gapic_version.py | 2 +- .../cloud/aiplatform_v1beta1/gapic_version.py | 2 +- pypi/_vertex_ai_placeholder/version.py | 2 +- ...t_metadata_google.cloud.aiplatform.v1.json | 2 +- ...adata_google.cloud.aiplatform.v1beta1.json | 2 +- 25 files changed, 55 insertions(+), 24 deletions(-) diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 2a7042a5dd..78e48be88a 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "1.50.0" + ".": "1.51.0" } diff --git a/CHANGELOG.md b/CHANGELOG.md index 683a1d2ec5..b42c727e4e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,36 @@ # Changelog +## 
[1.51.0](https://github.com/googleapis/python-aiplatform/compare/v1.50.0...v1.51.0) (2024-05-10) + + +### Features + +* Add FeatureGroup create function ([3938107](https://github.com/googleapis/python-aiplatform/commit/393810728b6b940e4cc8e1ac7f55875e3b750beb)) +* Add FeatureGroup init/get ([e47d436](https://github.com/googleapis/python-aiplatform/commit/e47d436f24cc718e378a28c4a80293778e8c183a)) +* Add support for BaseModels in LangChain templates ([5eb885e](https://github.com/googleapis/python-aiplatform/commit/5eb885ee7e01eece15679ce400f222930da1ac16)) +* Added the `vision_models.Image._mime_type` property to make `vision_models.Image` compatible with `generative_models.Image` ([6557d88](https://github.com/googleapis/python-aiplatform/commit/6557d88eb73624c8dbc7da33db129f7cbdae8a06)) +* AutoSxS Pairwise Metric in Rapid Evaluation SDK ([b0c5eda](https://github.com/googleapis/python-aiplatform/commit/b0c5eda79489d4b32972b2acea647e3c8cdc3ce9)) +* GenAI - Grounding - Released Google Web Search retriever to GA ([32b030a](https://github.com/googleapis/python-aiplatform/commit/32b030a629a20d0557dba011df2658f46c199820)) +* GenAI - Tuning - Supervised - Added support for the `adapter_size` parameter ([88188d2](https://github.com/googleapis/python-aiplatform/commit/88188d294fc2ec55ec0b05640dc791a1a3a88255)) +* LLM - Made the tuning location parameters truly optional ([bae8429](https://github.com/googleapis/python-aiplatform/commit/bae8429ae078c69574d86280ae6c784aaa9b13b5)) +* LLM - Support tuning of new text embedding models by migrating to the new v1.1.3 pipeline. 
([7fea754](https://github.com/googleapis/python-aiplatform/commit/7fea7547084277dc974cbacc517ca1e95629a034)) +* LLM - Text embedding - Added the `output_dimensionality` and `learning_rate_multiplier` parameters to text embedding tuning (Preview only) ([cc8bc96](https://github.com/googleapis/python-aiplatform/commit/cc8bc965932efb68a30db9decb5a24cf597b0d8b)) +* LLM - Text Embedding - Added validation for text embedding tuning parameters. ([5a300c1](https://github.com/googleapis/python-aiplatform/commit/5a300c1071fa1492502cfde95700e1b171cdfbfc)) +* Release Ray on Vertex SDK to GA ([f7c5132](https://github.com/googleapis/python-aiplatform/commit/f7c51327c49d000cc79d56bb5333ed7fea28fa01)) +* Support custom service account for Ray cluster creation and Ray Client connection ([e0c6227](https://github.com/googleapis/python-aiplatform/commit/e0c6227d0dd92d83c98cc3c7e7607fd252e74a32)) +* Support vector_distance_threshold filtering and file-based retrieval for RAG ([cd85d8f](https://github.com/googleapis/python-aiplatform/commit/cd85d8f74d3922de3f871415bacf77c594f0c547)) + + +### Bug Fixes + +* A bug in the evaluation library where the job crashes if only custom metrics are specified. 
([c528b6f](https://github.com/googleapis/python-aiplatform/commit/c528b6ff44e2347797336db800ca01240e670d32)) +* Add DeprecationWarning to vertexai.preview predictive models SDK ([3c3727b](https://github.com/googleapis/python-aiplatform/commit/3c3727b48ce4ba12bdaf36806cda4907a788d38e)) +* Add MAX_TOKENS to the list of successful finish reasons for Rapid Evaluation SDK ([195c77e](https://github.com/googleapis/python-aiplatform/commit/195c77ed7320aea3ab5899427a922d606ed78997)) +* AttributeError for TorchModelSerializer.deserialize in torch >=2.3.0 ([20b1866](https://github.com/googleapis/python-aiplatform/commit/20b18668f15c448813aad4f58f2a4d470d6da2ec)) +* GenAI - Fixed handling of multiple tools in `AutomaticFunctionCallingResponder` ([58e6ac9](https://github.com/googleapis/python-aiplatform/commit/58e6ac9b14daa42dc64d787156070c22bd7a1655)) +* Remove InternalServerError and Unknown evaluation service error from retriable exceptions ([12c147b](https://github.com/googleapis/python-aiplatform/commit/12c147b1f3e127c925b6c42b7dbbd4e949ff8e98)) +* Upload the reference model in model registry ([510c833](https://github.com/googleapis/python-aiplatform/commit/510c8334961cdb6f801863ecbd8fe49bf69b6c68)) + ## [1.50.0](https://github.com/googleapis/python-aiplatform/compare/v1.49.0...v1.50.0) (2024-05-02) diff --git a/google/cloud/aiplatform/gapic_version.py b/google/cloud/aiplatform/gapic_version.py index 39995f175a..05d77f299d 100644 --- a/google/cloud/aiplatform/gapic_version.py +++ b/google/cloud/aiplatform/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.51.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py index 39995f175a..05d77f299d 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.51.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py index 39995f175a..05d77f299d 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.51.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py index 39995f175a..05d77f299d 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.51.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py index 39995f175a..05d77f299d 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.51.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py index 39995f175a..05d77f299d 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.51.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py index 39995f175a..05d77f299d 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.51.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py index 39995f175a..05d77f299d 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.51.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py index 39995f175a..05d77f299d 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.51.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py index 39995f175a..05d77f299d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.51.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py index 39995f175a..05d77f299d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.51.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py index 39995f175a..05d77f299d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.51.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py index 39995f175a..05d77f299d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.51.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py index 39995f175a..05d77f299d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.51.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py index 39995f175a..05d77f299d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.51.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py index 39995f175a..05d77f299d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.51.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py index 39995f175a..05d77f299d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.51.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/version.py b/google/cloud/aiplatform/version.py index 1af552a3c1..8a9b6647c9 100644 --- a/google/cloud/aiplatform/version.py +++ b/google/cloud/aiplatform/version.py @@ -15,4 +15,4 @@ # limitations under the License. # -__version__ = "1.50.0" +__version__ = "1.51.0" diff --git a/google/cloud/aiplatform_v1/gapic_version.py b/google/cloud/aiplatform_v1/gapic_version.py index 39995f175a..05d77f299d 100644 --- a/google/cloud/aiplatform_v1/gapic_version.py +++ b/google/cloud/aiplatform_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.51.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform_v1beta1/gapic_version.py b/google/cloud/aiplatform_v1beta1/gapic_version.py index 39995f175a..05d77f299d 100644 --- a/google/cloud/aiplatform_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.50.0" # {x-release-please-version} +__version__ = "1.51.0" # {x-release-please-version} diff --git a/pypi/_vertex_ai_placeholder/version.py b/pypi/_vertex_ai_placeholder/version.py index de799ba97d..193f605b90 100644 --- a/pypi/_vertex_ai_placeholder/version.py +++ b/pypi/_vertex_ai_placeholder/version.py @@ -15,4 +15,4 @@ # limitations under the License. # -__version__ = "1.50.0" +__version__ = "1.51.0" diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json index a266a24b4b..6ed86ac0e9 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "0.1.0" + "version": "1.51.0" }, "snippets": [ { diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json index 0ce1d4d7c4..111f1b5c1e 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "0.1.0" + "version": "1.51.0" }, "snippets": [ {